{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import math\n",
    "from itertools import combinations, combinations_with_replacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using gpu: True \n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print('Using gpu: %s ' % torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Thinking like Transformers](https://arxiv.org/abs/2106.06981)\n",
    "\n",
    "Here we code our 'toy' GPT, without any training, to compute histograms. For the input sequence `<BOS>,a,a,b,a,b,c`, the output should be `0,3,3,2,3,2,1`, as the letter `a` appears 3 times, the letter `b` 2 times and the letter `c` once. Each letter is replaced by its number of occurrences (except `<BOS>`, which is replaced by `0`).\n",
    "\n",
    "## Self-Attention\n",
    "\n",
    "Start by coding your Self-Attention layer (do not worry about initialization for now)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttentionLayer(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.key_channels = config.key_channels\n",
    "        self.Query = nn.Linear(self.n_channels, self.key_channels, bias=False)\n",
    "        self.Key = nn.Linear(self.n_channels, self.key_channels, bias=False)\n",
    "        self.Value = nn.Linear(self.n_channels, self.n_channels, bias=False)\n",
    "    \n",
    "    def _init_hist(self):\n",
    "        # large Query/Key weights make the softmax (almost) hard;\n",
    "        # Value keeps only channel 0 (with the GPT_hist embedding, it is 1 only for <BOS>)\n",
    "        self.Query.weight.data = 100*torch.eye(self.key_channels, self.n_channels)\n",
    "        self.Key.weight.data = 100*torch.eye(self.key_channels, self.n_channels)\n",
    "        self.Value.weight.data = torch.zeros(self.n_channels, self.n_channels)\n",
    "        self.Value.weight.data[0,0] = 1.0\n",
    "        \n",
    "    def _init_id(self):\n",
    "        self.Query.weight.data = 100*torch.eye(self.key_channels, self.n_channels)\n",
    "        self.Key.weight.data = 100*torch.eye(self.key_channels, self.n_channels)\n",
    "        # Value maps n_channels -> n_channels, so its weight is a square identity\n",
    "        self.Value.weight.data = torch.eye(self.n_channels, self.n_channels)\n",
    "        \n",
    "    def forward(self, x): # x (bs, T, ic)\n",
    "        Q = self.Query(x) # (bs, T, kc)\n",
    "        K = self.Key(x)/math.sqrt(self.key_channels) # (bs, T, kc)\n",
    "        V = self.Value(x) # (bs, T, oc)\n",
    "        A = torch.einsum('btk,bsk->bts', Q, K) # (bs, T, kc), (bs, T, kc) -> (bs, T, T), A[b,t,s] = <Q_t, K_s>\n",
    "        A = F.softmax(A, dim=-1) # normalize over the key positions s\n",
    "        y = A @ V # (bs, T, oc)\n",
    "        return y, A"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
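    "Check your implementation: `y` should have the same shape as the input, and each row of the attention matrix `A` should sum to 1."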
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class toy_config:\n",
    "    n_channels = 3\n",
    "    key_channels = 3\n",
    "    \n",
    "sa_toy = SelfAttentionLayer(toy_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_toy = torch.randn(5, 10, 3) # (bs, T, n_channels)\n",
    "y,A = sa_toy(x_toy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 10, 3])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
       "         1.0000],\n",
       "        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
       "         1.0000],\n",
       "        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
       "         1.0000],\n",
       "        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
       "         1.0000],\n",
       "        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,\n",
       "         1.0000]], grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sum(A, dim=-1)"
   ]
  },
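  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an extra sanity check, here is a minimal sketch comparing an einsum-based attention with PyTorch's built-in `F.scaled_dot_product_attention` (this assumes PyTorch 2.0 or later, where that function exists). It uses the standard convention: scores are indexed as `A[b, t, s]` (query position `t`, key position `s`), the softmax runs over the keys `s`, and the scores are divided by `sqrt(kc)`. The tensors are random and the variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "bs, T, kc = 5, 10, 3\n",
    "Q = torch.randn(bs, T, kc)\n",
    "K = torch.randn(bs, T, kc)\n",
    "V = torch.randn(bs, T, kc)\n",
    "\n",
    "# A[b, t, s] = <Q_t, K_s> / sqrt(kc), normalized over the key positions s\n",
    "A_ref = F.softmax(torch.einsum('btk,bsk->bts', Q, K) / math.sqrt(kc), dim=-1)\n",
    "y_einsum = A_ref @ V\n",
    "\n",
    "# Built-in attention (PyTorch >= 2.0); its default scale is 1/sqrt(kc)\n",
    "y_builtin = F.scaled_dot_product_attention(Q, K, V)\n",
    "print(torch.allclose(y_einsum, y_builtin, atol=1e-5))  # expected: True"
   ]
  },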
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Identity GPT\n",
    "\n",
    "We start with a simple example where we want to construct the identity map. Clearly, in this case, we could just use the skip connections present in a real transformer block. Instead, we will ignore these skip connections and use the self-attention layer. Throughout this practical, we will also ignore layer normalization.\n",
    "\n",
    "To make our life simpler, we encode `<BOS>` with `0`, the letter `a` with `1`, and so on.\n",
    "\n",
    "If we give as input the sequence `0,1,1,2,3,4,2,3,1`, we want to get the same sequence as output. This can be done with a transformer block as follows:\n",
    "- use the one-hot encoding of each token\n",
    "- take the Query and Key matrices equal to `100*Id`\n",
    "- take the Value matrix equal to `Id`\n",
    "\n",
    "As a result, the output of the self-attention layer is the same as its input: the large factor `100` makes the softmax almost hard, so each token attends (almost) only to the positions holding the same token, and averaging identical one-hot vectors gives back that one-hot vector.\n",
    "\n",
    "For the feed-forward network, we simply use the identity map, as coded below:"
   ]
  },
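  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal standalone sketch of the argument above, assuming 5 token types and the same scaling as `SelfAttentionLayer` (keys divided by `sqrt(key_channels)`). With one-hot inputs and Query = Key = `100*Id`, the score between two positions is about 4472 when they hold the same token and 0 otherwise, so the softmax is effectively uniform over the positions holding the same token. The variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Illustrative input: the tokens 0,1,1,2,3,4,2,3,1 over 5 token types\n",
    "tokens = torch.tensor([0, 1, 1, 2, 3, 4, 2, 3, 1])\n",
    "x_oh = F.one_hot(tokens, num_classes=5).float()  # (T, 5)\n",
    "\n",
    "# Query = Key = 100*Id, keys scaled by 1/sqrt(5) as in SelfAttentionLayer\n",
    "scores = (100 * x_oh) @ (100 * x_oh / math.sqrt(5)).T  # scores[t, s] = <Q_t, K_s>\n",
    "A_id = F.softmax(scores, dim=-1)  # weight spread uniformly over same-token positions\n",
    "\n",
    "# Value = Id: averaging identical one-hot vectors returns the input\n",
    "y_id = A_id @ x_oh\n",
    "print(torch.allclose(y_id, x_oh))  # expected: True"
   ]
  },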
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block_id(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.attn = SelfAttentionLayer(config)\n",
    "        self.fake_mlp = (lambda x : x)\n",
    "        self.attn._init_id()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x, A = self.attn(x)\n",
    "        x = self.fake_mlp(x)\n",
    "        return x, A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_digits = 4 # number of distinct letters; token 0 is reserved for <BOS>\n",
    "class config:\n",
    "    n_channels=nb_digits+1\n",
    "    key_channels=nb_digits+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[0., 0., 1., 0., 0.],\n",
       "          [0., 1., 0., 0., 0.]]], grad_fn=<UnsafeViewBackward0>),\n",
       " tensor([[[1., 0.],\n",
       "          [0., 1.]]], grad_fn=<SoftmaxBackward0>))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bid = Block_id(config)\n",
    "one_sample = torch.tensor([[0.,0.,1.,0.,0.],[0.,1.,0.,0.,0.]]).unsqueeze(0)\n",
    "bid(one_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
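    "Now, to obtain a true identity map on token indices, we need to project the one-hot encoding back to an integer. This can be done with a linear layer with a well-chosen weight initialization: weights `0, 1, ..., n_channels-1`, so that the one-hot vector of token `i` is mapped to `i`."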
   ]
  },
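  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch of why this initialization works, assuming 5 token types: the dot product of the one-hot vector of token `i` with the vector `(0, 1, 2, 3, 4)` is exactly `i`. Here `head_weight` is an illustrative stand-in for `head.weight`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "tokens = torch.tensor([0, 1, 1, 2, 3, 4, 2, 3, 1])\n",
    "one_hot = F.one_hot(tokens, num_classes=5).float()  # (T, 5)\n",
    "head_weight = torch.arange(5, dtype=torch.float32)  # stand-in for head.weight\n",
    "recovered = one_hot @ head_weight  # <e_i, (0, 1, 2, 3, 4)> = i\n",
    "print(recovered)  # tensor([0., 1., 1., 2., 3., 4., 2., 3., 1.])"
   ]
  },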
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT_id(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.tok_emb = nn.Embedding(self.n_channels,self.n_channels)\n",
    "        self.block = Block_id(config)\n",
    "        self.head = nn.Linear(self.n_channels, 1, bias = False)\n",
    "        self._init_weights()\n",
    "        \n",
    "    def _init_weights(self):\n",
    "        self.tok_emb.weight.data = torch.eye(self.n_channels,self.n_channels)\n",
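    "        # weight[i] = i maps the one-hot vector of token i back to i; since this weight is 1-D,\n",
    "        # F.linear returns shape (bs, T) instead of (bs, T, 1)\n",
    "        self.head.weight.data = torch.arange(0, self.n_channels, dtype=torch.float32)\n",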
    "        \n",
    "    def forward(self, idx):\n",
    "        x = self.tok_emb(idx)\n",
    "        x, A = self.block(x)\n",
    "        return self.head(x), A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "gid = GPT_id(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "one_sample = torch.tensor([0,1,1,2,3,4,2,3,1]).unsqueeze(0)\n",
    "y, A = gid(one_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True, True, True, True, True]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y == one_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAGiCAYAAADUc67xAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAomElEQVR4nO3df3RU9ZnH8U8SzSRCZkQhCYHBIHZFfiuBbKBW2UZyKKZlz65FSks2tu7WDQrMadekFaKlMGArJ11BEBaEcyoF6y7WVcSD2QWWJZxAMD2wq1DLryw1AbZlBqJOdGb2D3VwlgRyMz/undz365x7Tuf23vk+zMg8PM/3e+9NC4fDYQEAANOkmx0AAAB2RzIGAMBkJGMAAExGMgYAwGQkYwAATEYyBgDAZCRjAABMRjIGAMBkJGMAAExGMgYAwGQkYwAAPrNnzx6Vl5eroKBAaWlpeuWVV655zq5du3TXXXfJ4XDotttu08aNGw2PSzIGAOAz7e3tGjt2rFatWtWt40+cOKHp06drypQpam5u1vz58/W9731Pb775pqFx03hQBAAAV0pLS9O2bds0Y8aMLo95/PHH9frrr+vIkSORfQ8++KAuXLigHTt2dHus62IJtCdCoZD+8Ic/KCcnR2lpackeHgAQg3A4rIsXL6qgoEDp6Ylrrn700Ufq6OiI+X3C4fAVucbhcMjhcMT83pLU0NCg0tLSqH1lZWWaP3++ofdJejL+wx/+ILfbnexhAQBx1NLSosGDByfkvT/66CMNHTpUra2tMb9X3759denSpah9tbW1evLJJ2N+b0lqbW1VXl5e1L68vDz5/X59+OGHys7O7tb7JD0Z5+TkSPr0i3Q6nckevtvyXS6zQwAAywlL+kiXf8sToaOjQ62trWppORFTnvD7/XK7h16Rb+JVFcdT0pPx5+0Cp9Np6WRMAx0AupaMacZ45YlE5pv8/Hy1tbVF7Wtra5PT6ex2VSyZkIwBAOieTz7bYjk/sUpKSrR9+/aofTt37lRJSYmh9+HSJgCARX0Sh82YS5cuqbm5Wc3NzZI+vXSpublZp0+fliTV1NRozpw5keO///3v6/jx4/qHf/gHvfvuu3ruuef00ksvacGCBYbGpTIGAFhU8ivjgwcPasqUKZHXHo9HklRRUaGNGzfq/fffjyRmSRo6dKhef/11LViwQL/4xS80ePBg/dM//ZPKysoMjZv064z9fr9cLpd8Pp+l54z7cNkVAFwhLOlDKaG/4ZfzxKmYF3C5XLdYPt9IVMYAAMsKKrbKOBivQBKOZAwAsCjrL+CKFxZwAQBgMipjAIBF2acyJhkDACzKPsmYNjUAACajMgYAWFRQsa2IZjU1AAAxss+lTbSpAQAwGZUxAMCiWMB1VatWrVJhYaGysrJUXFysxsbGeMcFALC95D8owiyGk/HWrVvl8XhUW1urQ4cOaezYsSorK9PZs2cTER8AwLZIxl1asWKFHn74YVVWVmrEiBFas2aNbrjhBm3YsKHT4wOBgPx+f9QGAAAuM5SMOzo61NTUpNLS0stvkJ6u0tJSNTQ0dHqO1+uVy+WKbG63O7aIAQA28flq6p5uvXQ19fnz5xUMBpWXlxe1Py8vT62trZ2eU1NTI5/PF9laWlp6Hi0AwEbs06ZO+Gpqh8Mhh8OR6GEAAEhZhpJx//79lZGRoba2tqj9bW1tys/Pj2tgAAC749KmTmVmZmr8+PGqr6+P7AuFQqqvr1dJSUncgwMA2Blt6i55PB5VVFSoqKhIEydOVF1dndrb21VZWZmI+AAA6PUMJ+OZM2fq3LlzWrRokVpbWzVu3Djt2LHjikVdAADExj5t6h4t4Jo7d67mzp0b71gAAPgCHhQBAACShAdFAAAsijY1AAAmIxkDAGAy+yRj5owBADAZlTEAwKLsUxmTjAEAFsWlTQAAIEmojAEAFhVUbNVt6lTGJGMAgEXZZ86YNjUAACajMgYAWJR9
KmPTknG+y6U0swbvhhNmB9BLDDU7ACAFWfn356Kk25I2GqupAQBAktCmBgBYFG1qAABMRjIGAMBk9knGzBkDAGAyKmMAgEXZpzImGQMALIpLmwAAQJJQGQMALOoTSRkxnp8aSMYAAIuyTzKmTQ0AgMmojAEAFmWfyphkDACwKFZTAwCAJKEyBgBY1CeKrWakTQ0AQIxIxgAAmMw+ydjwn3LPnj0qLy9XQUGB0tLS9MorryQgLAAA7MNwMm5vb9fYsWO1atWqRMQDAMBngnHYUoPhNvW0adM0bdq0bh8fCAQUCAQir/1+v9EhAQC2xKVNceP1euVyuSKb2+1O9JAAAKSUhCfjmpoa+Xy+yNbS0pLoIQEAvcIncdhSQ8JXUzscDjkcjkQPAwDodT6RlBbj+amBO3ABAGAyrjMGAFiUfSpjw8n40qVLeu+99yKvT5w4oebmZt10000aMmRIXIMDANgZybhLBw8e1JQpUyKvPR6PJKmiokIbN26MW2AAANiF4WR87733KhwOJyIWAAC+IKjYKuPUuc6YOWMAgEXF2mZOnTY1q6kBABZlznXGq1atUmFhobKyslRcXKzGxsarHl9XV6fbb79d2dnZcrvdWrBggT766CNDY5KMAQD4zNatW+XxeFRbW6tDhw5p7NixKisr09mzZzs9fvPmzaqurlZtba3eeecdrV+/Xlu3btWPfvQjQ+OSjAEAFpX8ynjFihV6+OGHVVlZqREjRmjNmjW64YYbtGHDhk6P37dvnyZPnqxvfetbKiws1NSpUzVr1qxrVtP/H8kYAGBRnz8ooqfbpwu4/H5/1PbFhxd9UUdHh5qamlRaWhrZl56ertLSUjU0NHR6zqRJk9TU1BRJvsePH9f27dv1ta99zdCflGQMAOjV3G531AOLvF5vp8edP39ewWBQeXl5Ufvz8vLU2tra6Tnf+ta39JOf/ERf/vKXdf3112vYsGG69957DbepWU0NALCoTyTFcintp5VxS0uLnE5nZG88n5ewa9cuLV26VM8995yKi4v13nvvad68eVq8eLEWLlzY7fchGQMALCo+ydjpdEYl4670799fGRkZamtri9rf1tam/Pz8Ts9ZuHChvvOd7+h73/ueJGn06NFqb2/X3/7t3+rHP/6x0tO714CmTQ0AgKTMzEyNHz9e9fX1kX2hUEj19fUqKSnp9JwPPvjgioSbkZEhSYZukEVlDACwqPhUxkZ4PB5VVFSoqKhIEydOVF1dndrb21VZWSlJmjNnjgYNGhSZdy4vL9eKFSt05513RtrUCxcuVHl5eSQpdwfJGABgUclPxjNnztS5c+e0aNEitba2aty4cdqxY0dkUdfp06ejKuEnnnhCaWlpeuKJJ3TmzBkNGDBA5eXlWrJkiaFx08JJvtG03++Xy+VStmK742iinTA7gF5iqNkBACnIyr8/FyXdJsnn83VrHrYnPs8TPt8wOZ3dry6vfJ+gXK7fJzTWeKEyBgBYVFCxVcaheAWScCRjAIBFkYwBADDZJ4rtop/UScZc2gQAgMmojAEAFmWfyphkDACwKPskY9rUAACYjMoYAGBRQcVW3Sb1NhoxIRkDACzqE8V2e6jUSca0qQEAMBmVMQDAouxTGZOMAQAWZZ9kTJsaAACTURkDAKwpHIqtuE2dwphkDACwqJBiu7Ipde75QTIGAFhU8LMtlvNTBHPGAACYjMoYAGBNVMad83q9mjBhgnJycpSbm6sZM2bo6NGjiYoNAGBnoThsKcJQMt69e7eqqqq0f/9+7dy5Ux9//LGmTp2q9vb2RMUHAECvZ6hNvWPHjqjXGzduVG5urpqamvSVr3yl03MCgYACgUDktd/v70GYAADboU3dPT6fT5J00003dXmM1+uVy+WKbG63O5YhAQB2YaM2dVo4HO7RZdGhUEhf//rXdeHCBe3du7fL4zqrjN1ut7IV203OEu2E2QH0EkPNDgBIQVb+/bko6TZ9Wow5nc6EjOH3++VyueQ7
JcUyhN8vuW5JbKzx0uPV1FVVVTpy5MhVE7EkORwOORyOng4DALCrkGJrNadQZdyjZDx37ly99tpr2rNnjwYPHhzvmAAAsNWcsaFkHA6H9eijj2rbtm3atWuXhg6lCQkAQKwMJeOqqipt3rxZv/nNb5STk6PW1lZJksvlUnZ2dkICBADYlI3uTW1oNfXq1avl8/l07733auDAgZFt69atiYoPAGBXwThsKcJwmxoAgKSw0ZwxD4oAAMBkPCgCAGBNNpozJhkDAKyJNjUAAEgWKmMAgDWFFVurOYXWHJOMAQDWRJsaAAAkC5UxAMCabFQZk4wBANZko0ubaFMDAGAyKmMAgDXRpgYAwGQkY/Ck5vhon212BNfW50WzI+g9+L7jw8q/P0m9dJc5YwAAkCxUxgAAawoptlZzClXGJGMAgDXRpgYAAMlCZQwAsCZWUwMAYDIbJWPa1AAAmIzKGABgTTZawEUyBgBYE21qAACQLFTGAABrslFlTDIGAFhTWLHN+yb1RtqxIRkDAKzJRpUxc8YAAJiMyhgAYE1c2gQAgMloUwMAgGQxlIxXr16tMWPGyOl0yul0qqSkRG+88UaiYgMA2FkwDluKMJSMBw8erGXLlqmpqUkHDx7UX/zFX+gb3/iG/uu//itR8QEA7CoUhy1FGJozLi8vj3q9ZMkSrV69Wvv379fIkSM7PScQCCgQCERe+/3+HoQJAEDv1eM542AwqC1btqi9vV0lJSVdHuf1euVyuSKb2+3u6ZAAADuhTd21w4cPq2/fvnI4HPr+97+vbdu2acSIEV0eX1NTI5/PF9laWlpiChgAYBMhxZaIe9imXrVqlQoLC5WVlaXi4mI1NjZe9fgLFy6oqqpKAwcOlMPh0J/92Z9p+/bthsY0fGnT7bffrubmZvl8Pr388suqqKjQ7t27u0zIDodDDofD6DAAALsz4TrjrVu3yuPxaM2aNSouLlZdXZ3Kysp09OhR5ebmXnF8R0eH7rvvPuXm5urll1/WoEGDdOrUKd14442GxjWcjDMzM3XbbbdJksaPH68DBw7oF7/4hZ5//nmjbwUAgKWsWLFCDz/8sCorKyVJa9as0euvv64NGzaourr6iuM3bNigP/7xj9q3b5+uv/56SVJhYaHhcWO+zjgUCkUt0AIAIC7iNGfs9/ujtq5yVkdHh5qamlRaWhrZl56ertLSUjU0NHR6zquvvqqSkhJVVVUpLy9Po0aN0tKlSxUMGpuwNpSMa2pqtGfPHp08eVKHDx9WTU2Ndu3apdmzZxsaFACAa4rTpU1utztqIbHX6+10uPPnzysYDCovLy9qf15enlpbWzs95/jx43r55ZcVDAa1fft2LVy4UM8884x++tOfGvqjGmpTnz17VnPmzNH7778vl8ulMWPG6M0339R9991naFAAAJKlpaVFTqcz8jqe65hCoZByc3O1du1aZWRkaPz48Tpz5ox+9rOfqba2ttvvYygZr1+/3nCgAAD0SJzuTf35XSOvpX///srIyFBbW1vU/ra2NuXn53d6zsCBA3X99dcrIyMjsu+OO+5Qa2urOjo6lJmZ2a1QuTc1AMCaknydcWZmpsaPH6/6+vrIvlAopPr6+i7vpzF58mS99957CoUuL90+duyYBg4c2O1ELJGMAQCI8Hg8WrdunTZt2qR33nlHjzzyiNrb2yOrq+fMmaOamprI8Y888oj++Mc/at68eTp27Jhef/11LV26VFVVVYbG5RGKAABrMuE645kzZ+rcuXNatGiRWltbNW7cOO3YsSOyqOv06dNKT79cx7rdbr355ptasGCBxowZo0GDBmnevHl6/PHHDY2bFg6Hw8bD7Tm/3y+Xy6VsSWnJHBimaE+BhfZ9XjQ7gt6D77v3C0v6UJLP5+vWPGxPfJ4nfMskZ1YM7/OR5KpObKzxQpsaAACT0aYGAFiTCW1qs5CMAQDWFKdLm1IByRgAYE02SsbMGQMAYDIqYwCANTFnDACAyWhTAwCA
ZKEyRkKlwg0WUuFGFVJqfJapEGMqfN+p8DkmhY0qY5IxAMCawopt3jep95eMDW1qAABMRmUMALAm2tQAAJjMRpc20aYGAMBkVMYAAGuiTQ0AgMlIxgAAmIw5YwAAkCxUxgAAa6JNDQCAyUKKLaHSpgYAAN1FZQwAsCYbLeAiGQMArMlGc8a0qQEAMBmVMQDAmmhTAwBgMtrU3bNs2TKlpaVp/vz5cQoHAAD76XFlfODAAT3//PMaM2ZMPOMBAOBTVMZXd+nSJc2ePVvr1q1Tv379rnpsIBCQ3++P2gAAuKZQHLYU0aNkXFVVpenTp6u0tPSax3q9Xrlcrsjmdrt7MiQAwG4+vwNXT7fenIy3bNmiQ4cOyev1duv4mpoa+Xy+yNbS0mI4SAAAejNDc8YtLS2aN2+edu7cqaysrG6d43A45HA4ehQcAMDGgoptmXEKzRkbSsZNTU06e/as7rrrrsi+YDCoPXv2aOXKlQoEAsrIyIh7kAAAG+I648599atf1eHDh6P2VVZWavjw4Xr88cdJxAAA9IChZJyTk6NRo0ZF7evTp49uvvnmK/YDABAT2tQAAJiMNnX37dq1Kw5hAABgX1TGAABrok0NAIDJbJSMeZ4xAAAmozIGAFhTWLEtwgrHK5DEIxkDAKwpKCktxvNTBMkYAGBNNkrGzBkDAGAyKmMAgDVx0w8AAExGmxoAACQLlTEAwJpoUwMAYDIbtalJxrC9Pi+aHUH3tIetfweDPmmx/HImR6p837AXkjEAwJpCiq26pU0NAECMQoqtTZ1CyZjV1AAAmIzKGABgTbEuwGIBFwAAMSIZAwBgMuaMAQBAslAZAwCsiTY1AAAmo00NAACShcoYAGBNsVa2KVQZk4wBANYUlBTLLdlTKBnTpgYAwGQkYwCANYXisPXAqlWrVFhYqKysLBUXF6uxsbFb523ZskVpaWmaMWOG4TFJxgAAawrGYTNo69at8ng8qq2t1aFDhzR27FiVlZXp7NmzVz3v5MmT+sEPfqC7777b+KAiGQMAejm/3x+1BQKBLo9dsWKFHn74YVVWVmrEiBFas2aNbrjhBm3YsKHLc4LBoGbPnq2nnnpKt956a49iJBkDAKwpTpWx2+2Wy+WKbF6vt9PhOjo61NTUpNLS0si+9PR0lZaWqqGhocswf/KTnyg3N1ff/e53e/xHNbSa+sknn9RTTz0Vte/222/Xu+++2+MAAADoVJwubWppaZHT6YzsdjgcnR5+/vx5BYNB5eXlRe3Py8vrMs/t3btX69evV3Nzc0yhGr60aeTIkXrrrbcuv8F1XB0FAEiAkGK7tOmzc51OZ1QyjpeLFy/qO9/5jtatW6f+/fvH9F6GM+l1112n/Pz8mAYFAMBq+vfvr4yMDLW1tUXtb2tr6zTv/f73v9fJkydVXl4e2RcKfVqOX3fddTp69KiGDRvWrbENzxn/7ne/U0FBgW699VbNnj1bp0+fvurxgUDgislzAACuKcmXNmVmZmr8+PGqr6+/HEIopPr6epWUlFxx/PDhw3X48GE1NzdHtq9//euaMmWKmpub5Xa7uz22ocq4uLhYGzdu1O233673339fTz31lO6++24dOXJEOTk5nZ7j9XqvmGcGAOCagortQRE9aHF7PB5VVFSoqKhIEydOVF1dndrb21VZWSlJmjNnjgYNGiSv16usrCyNGjUq6vwbb7xRkq7Yfy2GkvG0adMi/3vMmDEqLi7WLbfcopdeeqnLVWQ1NTXyeDyR136/39C/FgAASJaZM2fq3LlzWrRokVpbWzVu3Djt2LEjsqjr9OnTSk+P/4VIaeFwOJbpcU2YMEGlpaVdLhX///x+v1wul7IV2z94ALtpj+2valL0SeNvdW8XlvShJJ/Pl5BFUdLlPOHLlpwx/CflD0uuDxMba7zElN4vXbqk3//+9xo4cGC84gEA4FMm3Q7TDIaS8Q9+8APt3r1bJ0+e1L59+/SXf/mXysjI0KxZsxIVHwAAvZ6h
OeP/+Z//0axZs/S///u/GjBggL785S9r//79GjBgQKLiAwDYlQkLuMxiKBlv2bIlUXEAABDNRsmYe1MDAGAy7mUJALCmsFKquo0FyRgAYEk9fCRx1PmpgmQMALAkOyVj5owBADAZlTEAwJJivW9HCt3zg2QMALAm2tQAACBpqIwBAJZEmxoAAJPRpgYAAElDZQwAsKSQYqtuaVMDABAj5oyBOGmfbXYE19bnRbMj6J4+abE8viY5+L6BniEZAwAsyU4LuEjGAABLIhkDAGAyO80Zc2kTAAAmozIGAFgSbWoAAExGmxoAACQNlTEAwJK4AxcAACaz05wxbWoAAExGZQwAsCQ7LeAiGQMALIk2NQAASBoqYwCAJdmpMiYZAwAsiTljAABMZqfK2PCc8ZkzZ/Ttb39bN998s7KzszV69GgdPHgwEbEBAGALhirjP/3pT5o8ebKmTJmiN954QwMGDNDvfvc79evXL1HxAQBsKqzYWs3heAWSBIaS8fLly+V2u/XCCy9E9g0dOjTuQQEAQJu6C6+++qqKior0wAMPKDc3V3feeafWrVt31XMCgYD8fn/UBgAALjOUjI8fP67Vq1frS1/6kt5880098sgjeuyxx7Rp06Yuz/F6vXK5XJHN7XbHHDQAoPcLxmFLFWnhcLjbbfXMzEwVFRVp3759kX2PPfaYDhw4oIaGhk7PCQQCCgQCkdd+v19ut1vZktJ6HjdSRPtssyO4tj4vmh1B78H33fuFJX0oyefzyel0JmQMv98vl8ul1yX1ieF92iVNV2JjjRdDlfHAgQM1YsSIqH133HGHTp8+3eU5DodDTqczagMAAJcZWsA1efJkHT16NGrfsWPHdMstt8Q1KAAA7LSAy1AyXrBggSZNmqSlS5fqm9/8phobG7V27VqtXbs2UfEBAGzKTsnYUJt6woQJ2rZtm371q19p1KhRWrx4serq6jR7dgpMFAEAYFGGb4d5//336/77709ELAAARHBvagAATBZSbK1mkjEAADGyU2Vs+EERAAAgvqiMAQCWZKfV1CRjAIAl2SkZ06YGAMBkVMYAAEuy0wIukjEAwJJoUwMAgKShMgYAWJKdKmOSMQDAksKKbd43HK9AkoA2NQAAJqMyBgBYEm1qIE76vGh2BNfWniJPAE2FzzIVYkyF7zsVPsdk4NImAABMZqfKmDljAABMRmUMALAkO1XGJGMAgCXZac6YNjUAAF+watUqFRYWKisrS8XFxWpsbOzy2HXr1unuu+9Wv3791K9fP5WWll71+K6QjAEAlhSMw2bU1q1b5fF4VFtbq0OHDmns2LEqKyvT2bNnOz1+165dmjVrlv793/9dDQ0Ncrvdmjp1qs6cOWNo3LRwOJzUm5T4/X65XC5lS0pL5sBAF1LhUheJy13iJRW+byt/12FJH0ry+XxyOp0JGePzPLFMUlYM7/ORpGpJLS0tUbE6HA45HI5OzykuLtaECRO0cuVKSVIoFJLb7dajjz6q6urqa44ZDAbVr18/rVy5UnPmzOl2rFTGAIBeze12y+VyRTav19vpcR0dHWpqalJpaWlkX3p6ukpLS9XQ0NCtsT744AN9/PHHuummmwzFyAIuAIAlxWsBV2eVcWfOnz+vYDCovLy8qP15eXl69913uzXm448/roKCgqiE3h0kYwCAJcXr0ian05mwlvoXLVu2TFu2bNGuXbuUlWWswU4yBgBAUv/+/ZWRkaG2trao/W1tbcrPz7/quT//+c+1bNkyvfXWWxozZozhsZkzBgBYUigOmxGZmZkaP3686uvrL8cQCqm+vl4lJSVdnvf0009r8eLF2rFjh4qKigyO+ikqYwCAJZlxBy6Px6OKigoVFRVp4sSJqqurU3t7uyorKyVJc+bM0aBBgyKLwJYvX65FixZp8+bNKiwsVGtrqySpb9++6tu3b7fHJRkDACzJjGQ8c+ZMnTt3TosWLVJra6vGjRunHTt2RBZ1nT59Wunpl5vKq1evVkdH
h/76r/866n1qa2v15JNPdntckjEAAF8wd+5czZ07t9P/b9euXVGvT548GZcxScYAAEvi3tRdKCwsVFpa2hVbVVVVouIDANhUSLHdCjOVkrGhyvjAgQMKBi934Y8cOaL77rtPDzzwQNwDAwDALgwl4wEDBkS9XrZsmYYNG6Z77rknrkEBAMDzjLuho6NDv/zlL+XxeJSW1vUjHwKBgAKBQOS13+/v6ZAAABthzrgbXnnlFV24cEF/8zd/c9XjvF5v1A263W53T4cEAKBX6nEyXr9+vaZNm6aCgoKrHldTUyOfzxfZWlpaejokAMBGzHiesVl61KY+deqU3nrrLf3Lv/zLNY+92nMjAQDoCm3qa3jhhReUm5ur6dOnxzseAABsx3BlHAqF9MILL6iiokLXXcc9QwAAicFq6qt46623dPr0aT300EOJiAcAAEkk46uaOnWqwuFwImIBACAirNjmfVMpU/E8YwAATMakLwDAkmhTAwBgMjslY9rUAACYjMoYAGBJdrrpB8kYAGBJtKkBAEDSUBkDACyJNjUAACajTQ0AAJKGyhgAYEkhxVbd0qYGACBGzBkDAGCyoGKbS02lOWOScRdOmB1ALzHU7AC6oc+LZkeAZEqF79vKvz8XJd1mdhC9EMkYAGBJVMYAAJjMTnPGXNoEAIDJqIwBAJZEmxoAAJPRpgYAAElDZQwAsCTuwAUAgMmCktJiPD9V0KYGAMBkVMYAAEuy0wIukjEAwJLs1KYmGQMALMlOyZg5YwAATEZlDACwJOaMAQAwGW1qAACQNIaScTAY1MKFCzV06FBlZ2dr2LBhWrx4scLhcKLiAwDYVFiXW9U92VIpMxlqUy9fvlyrV6/Wpk2bNHLkSB08eFCVlZVyuVx67LHHEhUjAMCGYm0zp1Kb2lAy3rdvn77xjW9o+vTpkqTCwkL96le/UmNjY0KCAwDADgy1qSdNmqT6+nodO3ZMkvTb3/5We/fu1bRp07o8JxAIyO/3R20AAFxLMA5bqjBUGVdXV8vv92v48OHKyMhQMBjUkiVLNHv27C7P8Xq9euqpp2IOFABgLyHFtpo6lS5tMlQZv/TSS3rxxRe1efNmHTp0SJs2bdLPf/5zbdq0qctzampq5PP5IltLS0vMQQMA0JsYqox/+MMfqrq6Wg8++KAkafTo0Tp16pS8Xq8qKio6PcfhcMjhcMQeKQDAVljA1YUPPvhA6enRxXRGRoZCoVRqBgAAUgHJuAvl5eVasmSJhgwZopEjR+rtt9/WihUr9NBDDyUqPgCATdlpzthQMn722We1cOFC/f3f/73Onj2rgoIC/d3f/Z0WLVqUqPgAAOj10sJJvn2W3++Xy+VStmL7F0+inTA7gF5iqNkBACnIyr8/FyXdJsnn88npdCZkjM/zRKFiu2dzSNJJJTbWeOFBEQAAS4q1zZxKbWoeFAEAgMmojAEAlhRUbA97SKXKmGQMALAkOyVj2tQAAJiMyhgAYEl2WsBFMgYAWBJtagAAkDRUxgAASwoptso4qXe0ihHJGABgSbHemzqVkjFtagCAJQXjsPXEqlWrVFhYqKysLBUXF6uxsfGqx//617/W8OHDlZWVpdGjR2v79u2GxyQZAwDwma1bt8rj8ai2tlaHDh3S2LFjVVZWprNnz3Z6/L59+zRr1ix997vf1dtvv60ZM2ZoxowZOnLkiKFxk/6gCJ/PpxtvvFFZsvaDIg6bHUAvMdrsAIAUZOXfn4uS7pR04cIFuVyuhIwRrwcKhSV9KKmlpSXqQREOh0MOh6PTc4qLizVhwgStXLlSkhQKheR2u/Xoo4+qurr6iuNnzpyp9vZ2vfbaa5F9f/7nf65x48ZpzZo13Y416XPGFy9elCR9lOyBDbrN7AAA2FYq/P5cvHgxYck4MzNT+fn5am1tjfm9+vbtK7fbHbWvtrZWTz755BXHdnR0qKmpSTU1NZF96enpKi0tVUNDQ6fv39DQII/HE7WvrKxMr7zyiqE4k56M
CwoK1NLSopycHKWlxV4b+/1+ud3uK/7lA2P4HOODzzF++CzjI96fYzgc1sWLF1VQUBCH6DqXlZWlEydOqKOjI+b3CofDV+Sarqri8+fPKxgMKi8vL2p/Xl6e3n333U7PaW1t7fR4o/+QSHoyTk9P1+DBg+P+vk6nk7+wccDnGB98jvHDZxkf8fwcE1URf1FWVpaysrISPo5VsIALAABJ/fv3V0ZGhtra2qL2t7W1KT8/v9Nz8vPzDR3fFZIxAAD6dK56/Pjxqq+vj+wLhUKqr69XSUlJp+eUlJREHS9JO3fu7PL4rqT8TT8cDodqa2u7nANA9/A5xgefY/zwWcYHn6MxHo9HFRUVKioq0sSJE1VXV6f29nZVVlZKkubMmaNBgwbJ6/VKkubNm6d77rlHzzzzjKZPn64tW7bo4MGDWrt2raFxk35pEwAAVrZy5Ur97Gc/U2trq8aNG6d//Md/VHFxsSTp3nvvVWFhoTZu3Bg5/te//rWeeOIJnTx5Ul/60pf09NNP62tf+5qhMUnGAACYjDljAABMRjIGAMBkJGMAAExGMgYAwGQpn4yNPuoK0bxeryZMmKCcnBzl5uZqxowZOnr0qNlhpbxly5YpLS1N8+fPNzuUlHPmzBl9+9vf1s0336zs7GyNHj1aBw8eNDuslBIMBrVw4UINHTpU2dnZGjZsmBYvXizW61pXSidjo4+6wpV2796tqqoq7d+/Xzt37tTHH3+sqVOnqr293ezQUtaBAwf0/PPPa8yYMWaHknL+9Kc/afLkybr++uv1xhtv6L//+7/1zDPPqF+/fmaHllKWL1+u1atXa+XKlXrnnXe0fPlyPf3003r22WfNDg1dSOlLm4w+6grXdu7cOeXm5mr37t36yle+YnY4KefSpUu666679Nxzz+mnP/2pxo0bp7q6OrPDShnV1dX6z//8T/3Hf/yH2aGktPvvv195eXlav359ZN9f/dVfKTs7W7/85S9NjAxdSdnK+PNHXZWWlkb2XetRV7g2n88nSbrppptMjiQ1VVVVafr06VH/XaL7Xn31VRUVFemBBx5Qbm6u7rzzTq1bt87ssFLOpEmTVF9fr2PHjkmSfvvb32rv3r2aNm2ayZGhKyl7O8yePOoKVxcKhTR//nxNnjxZo0aNMjuclLNlyxYdOnRIBw4cMDuUlHX8+HGtXr1aHo9HP/rRj3TgwAE99thjyszMVEVFhdnhpYzq6mr5/X4NHz5cGRkZCgaDWrJkiWbPnm12aOhCyiZjxF9VVZWOHDmivXv3mh1KymlpadG8efO0c+dOWz32Ld5CoZCKioq0dOlSSdKdd96pI0eOaM2aNSRjA1566SW9+OKL2rx5s0aOHKnm5mbNnz9fBQUFfI4WlbLJuCePukLX5s6dq9dee0179uxJyPOme7umpiadPXtWd911V2RfMBjUnj17tHLlSgUCAWVkZJgYYWoYOHCgRowYEbXvjjvu0D//8z+bFFFq+uEPf6jq6mo9+OCDkqTRo0fr1KlT8nq9JGOLStk545486gpXCofDmjt3rrZt26Z/+7d/09ChQ80OKSV99atf1eHDh9Xc3BzZioqKNHv2bDU3N5OIu2ny5MlXXFp37Ngx3XLLLSZFlJo++OADpadH/7xnZGQoFAqZFBGuJWUrY+naj7rCtVVVVWnz5s36zW9+o5ycHLW2tkqSXC6XsrOzTY4udeTk5Fwxz96nTx/dfPPNzL8bsGDBAk2aNElLly7VN7/5TTU2Nmrt2rWGH0dnd+Xl5VqyZImGDBmikSNH6u2339aKFSv00EMPmR0auhJOcc8++2x4yJAh4czMzPDEiRPD+/fvNzuklCKp0+2FF14wO7SUd88994TnzZtndhgp51//9V/Do0aNCjscjvDw4cPDa9euNTuklOP3+8Pz5s0LDxkyJJyVlRW+9dZbwz/+8Y/DgUDA7NDQhZS+zhgAgN4gZeeMAQDoLUjGAACYjGQMAIDJSMYAAJiMZAwAgMlI
xgAAmIxkDACAyUjGAACYjGQMAIDJSMYAAJiMZAwAgMn+D9n2LSlEmFVLAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(A[0].detach().cpu(), cmap='hot', interpolation='nearest')\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Histogram GPT\n",
    "\n",
    "Now adapt the previous case to code your 'toy' transformer block and your 'toy' GPT to compute histograms:\n",
    "- you will need to find a good initialization for the Query, Key and Value matrices\n",
    "- for the feed-forward network, you can fake the MLP with any function you like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block_hist(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.attn = SelfAttentionLayer(config)\n",
    "        # x.sum(-1) is the attention weight on <BOS>: 1/(count+1) for a letter, 1 for <BOS>\n",
    "        self.fake_mlp = (lambda x: 1/x.sum(-1, keepdim=True) - 1)\n",
    "        self.attn._init_hist()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x, A = self.attn(x)\n",
    "        x = self.fake_mlp(x)\n",
    "        return x, A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT_hist(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.tok_emb = nn.Embedding(self.n_channels,self.n_channels)\n",
    "        self.block = Block_hist(config)\n",
    "        self._init_weights()\n",
    "        \n",
    "    def _init_weights(self):\n",
    "        self.tok_emb.weight.data = torch.eye(self.n_channels,self.n_channels)\n",
    "        # <BOS> is embedded as all ones, so every letter also attends to it\n",
    "        self.tok_emb.weight.data[0,:] = torch.ones(self.n_channels)\n",
    "        \n",
    "    def forward(self, idx):\n",
    "        x = self.tok_emb(idx)\n",
    "        x, A = self.block(x)\n",
    "        return x, A"
   ]
  },
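  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short derivation of why this initialization computes histograms (a sketch, assuming a single `<BOS>` at the start of the sequence). Let $x_t$ be the embedding at position $t$: $x_t = \\mathbf{1}$ for `<BOS>` and $x_t = e_i$ for the letter $i$. Since Query and Key are both $100\\,\\mathrm{Id}$, the score between positions $t$ and $s$ is proportional to $x_t^\\top x_s$. For a letter $i$ with $c_i$ occurrences, this equals $1$ for the $c_i$ positions holding $i$ and for `<BOS>` (since $\\mathbf{1}^\\top e_i = 1$), and $0$ elsewhere, so the (almost hard) softmax gives\n",
    "$$a_{t,\\mathrm{BOS}} = \\frac{1}{c_i + 1}.$$\n",
    "The Value matrix keeps only channel $0$, which is nonzero only for `<BOS>`, so the attention output summed over channels is $a_{t,\\mathrm{BOS}}$ and the fake MLP returns\n",
    "$$\\frac{1}{a_{t,\\mathrm{BOS}}} - 1 = c_i.$$\n",
    "For `<BOS>` itself, $\\mathbf{1}^\\top \\mathbf{1} = n$ (with $n$ the number of channels) dominates every other score, so `<BOS>` attends only to itself and the output is $1/1 - 1 = 0$."
   ]
  },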
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check your implementation, after first choosing your configuration properly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "gh = GPT_hist(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.],\n",
       "         [3.],\n",
       "         [3.],\n",
       "         [2.],\n",
       "         [2.],\n",
       "         [1.],\n",
       "         [2.],\n",
       "         [2.],\n",
       "         [3.]]], grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "one_sample = torch.tensor([0,1,1,2,3,4,2,3,1]).unsqueeze(0)\n",
    "y, A = gh(one_sample)\n",
    "y"
   ]
  },
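  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare `y` with the target described at the top of the notebook, the histogram can also be computed directly, without attention. This is a sketch: `histogram_reference` is a hypothetical helper, and it assumes token `0` is `<BOS>`. With the `y` computed above (shape `(1, 9, 1)`), `torch.allclose(y.squeeze(-1), histogram_reference(one_sample).float())` should then hold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def histogram_reference(idx):\n",
    "    # idx: (bs, T) integer tokens, with 0 = <BOS>\n",
    "    # counts[b, t] = number of positions s such that idx[b, s] == idx[b, t]\n",
    "    counts = (idx.unsqueeze(-1) == idx.unsqueeze(-2)).sum(-1)\n",
    "    # by convention, <BOS> is mapped to 0\n",
    "    return torch.where(idx == 0, torch.zeros_like(counts), counts)\n",
    "\n",
    "target = histogram_reference(torch.tensor([[0, 1, 1, 2, 3, 4, 2, 3, 1]]))\n",
    "print(target)  # tensor([[0, 3, 3, 2, 2, 1, 2, 2, 3]])"
   ]
  },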
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 9, 1])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAGiCAYAAADUc67xAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoYUlEQVR4nO3df3RU9Z3/8VcSzSRKMqKQhMBgItuK/FYC2UCtso1wKKZlz64ipU02tu7WDQrktGtShWgpDGhlsysIwiJwTqVg3cW6ingwu8CyhBMIpgd2BYqAzFIngd2akaATnZnvH+rgfEkgN/Pj3sl9Ps75nNPc3juftxPh7fv9+dx7U0KhUEgAAMA0qWYHAACA3ZGMAQAwGckYAACTkYwBADAZyRgAAJORjAEAMBnJGAAAk5GMAQAwGckYAACTkYwBADAZyRgAgC/s2bNHZWVlys/PV0pKil599dWrXrNr1y7dcccdcjgc+pM/+RNt3LjR8LwkYwAAvtDR0aGxY8dq1apVPTr/1KlTmjFjhqZMmaKWlhbNnz9fP/rRj/TWW28ZmjeFF0UAAHC5lJQUbdu2TTNnzuz2nMcee0xvvPGGjhw5Ej72wAMP6MMPP9SOHTt6PNc10QTaG8FgUH/4wx+UlZWllJSURE8PAIhCKBTSRx99pPz8fKWmxq+5+sknn6izszPqzwmFQpflGofDIYfDEfVnS1JjY6NKS0sjjk2bNk3z58839DkJT8Z/+MMf5HK5Ej0tACCGPB6PhgwZEpfP/uSTT1RYWCiv1xv1Z/Xr108XLlyIOFZXV6cnn3wy6s+WJK/Xq9zc3Ihjubm58vl8+vjjj5WZmdmjz0l4Ms7KypL0+S8yOzs70dP3WJ7TaXYIAGA5IUmf6NLf5fHQ2dkpr9crj+dUVHnC5/PJ5Sq8LN/EqiqOpYQn4y/bBdnZ2ZZOxjTQAaB7iVhmjFWeiGe+ycvLU2tra8Sx1tZWZWdn97gqlkxIxgAA9MxnX4xoro+vkpISbd++PeLYzp07VVJSYuhzuLUJAGBRn8VgGHPhwgW1tLSopaVF0ue3LrW0tOjMmTOSpNraWpWXl4fP//GPf6yTJ0/q7/7u73T06FE9//zzevnll7VgwQJD81IZAwAsKvGV8cGDBzVlypTwz9XV1ZKkiooKbdy4UR988EE4MUtSYWGh3njjDS1YsED/8A//oCFDhuif/umfNG3aNEPzJvw+Y5/PJ6fTqfb2dkuvGV/PbVcAcJmQpI+luP4dfilPvB/1Bi6n82bL5xuJyhgAYFkBRVcZB2IVSNyRjAEAFmX9DVyxwgYuAABMRmUMALAo+1TGJGMAgEXZJxnTpgYAwGRUxgAAiwoouh3R7KYGACBK9rm1iTY1AAAmozIGAFgUG7iuaNWqVSooKFBGRoaKi4vV1NQU67gAALaX+BdFmMVwMt66dauqq6tVV1enQ4cOaezYsZo2bZra2triER8AwLZIxt1asWKFHnroIVVWVmrEiBFas2aNrrvuOr344otdnu/3++Xz+SIGAAC4xFAy7uzsVHNzs0pLSy99QGqqSktL1djY2OU1brdbTqczPFwuV3QRAwBs4svd1L0dfXQ39fnz5xUIBJSbmxtxPDc3V16vt8tramtr1d7eHh4ej6f30QIAbMQ+beq476Z2OBxyOBzxngYAgKRlKBkPGDBAaWlpam1tjTje2tqqvLy8mAYGALA7bm3qUnp6usaPH6+GhobwsWAwqIaGBpWUlMQ8OACAndGm7lZ1dbUqKipUVFSkiRMnqr6+Xh0dHaqsrIxHfAAA9HmGk/GsWbN07tw5LVq0SF6vV+PGjdOOHTsu29QFAEB07NOm7tUGrrlz52ru3LmxjgUAgK/gRREAACBBeFEEAMCiaFMDAGAykjEAACazTzJmzRgAAJNRGQMALMo+lTHJGABgUdzaBAAAEoTKGABgUQFFV90mT2VMMgYAWJR91oxpUwMAYDIqYwCA
RdmnMjYtGb/pdOo6sybvgS1mB9BHPGB2AEASsvLfPxeVyD/X7KYGAAAJQpsaAGBRtKkBADAZyRgAAJPZJxmzZgwAgMmojAEAFmWfyphkDACwKG5tAgAACUJlDACwqM8kpUV5fXIgGQMALMo+yZg2NQAAJqMyBgBYlH0qY5IxAMCi2E0NAAAShMoYAGBRnym6mpE2NQAAUSIZAwBgMvskY8P/lHv27FFZWZny8/OVkpKiV199NQ5hAQBgH4aTcUdHh8aOHatVq1bFIx4AAL4QiMFIDobb1NOnT9f06dN7fL7f75ff7w//7PP5jE4JALAlbm2KGbfbLafTGR4ulyveUwIAkFTinoxra2vV3t4eHh6PJ95TAgD6hM9iMJJD3HdTOxwOORyOeE8DAOhzPpOUEuX1yYEncAEAYDLuMwYAWJR9KmPDyfjChQs6ceJE+OdTp06ppaVFN954o4YOHRrT4AAAdkYy7tbBgwc1ZcqU8M/V1dWSpIqKCm3cuDFmgQEAYBeGk/Hdd9+tUCgUj1gAAPiKgKKrjJPnPmPWjAEAFhVtmzl52tTspgYAWJQ59xmvWrVKBQUFysjIUHFxsZqamq54fn19vW699VZlZmbK5XJpwYIF+uSTTwzNSTIGAOALW7duVXV1terq6nTo0CGNHTtW06ZNU1tbW5fnb968WTU1Naqrq9O7776r9evXa+vWrfrZz35maF6SMQDAohJfGa9YsUIPPfSQKisrNWLECK1Zs0bXXXedXnzxxS7P37dvnyZPnqzvfe97Kigo0NSpUzV79uyrVtP/P5IxAMCivnxRRG/H5xu4fD5fxPjqy4u+qrOzU83NzSotLQ0fS01NVWlpqRobG7u8ZtKkSWpubg4n35MnT2r79u369re/beiflGQMAOjTXC5XxAuL3G53l+edP39egUBAubm5Ecdzc3Pl9Xq7vOZ73/uefv7zn+sb3/iGrr32Wg0bNkx333234TY1u6kBABb1maRobqX9vDL2eDzKzs4OH43l+xJ27dqlpUuX6vnnn1dxcbFOnDihefPmafHixVq4cGGPP4dkDACwqNgk4+zs7Ihk3J0BAwYoLS1Nra2tEcdbW1uVl5fX5TULFy7UD37wA/3oRz+SJI0ePVodHR3667/+az3++ONKTe1ZA5o2NQAAktLT0zV+/Hg1NDSEjwWDQTU0NKikpKTLay5evHhZwk1LS5MkQw/IojIGAFhUbCpjI6qrq1VRUaGioiJNnDhR9fX16ujoUGVlpSSpvLxcgwcPDq87l5WVacWKFbr99tvDbeqFCxeqrKwsnJR7gmQMALCoxCfjWbNm6dy5c1q0aJG8Xq/GjRunHTt2hDd1nTlzJqISfuKJJ5SSkqInnnhCZ8+e1cCBA1VWVqYlS5YYmjcllOAHTft8PjmdTm2RdF0iJ4YpHjA7ACAJbTE7gCu4qM//XLe3t/doHbY3vswT7e3DlJ3d8+ry8s8JyOl8L66xxgqVMQDAogKKrjIOxiqQuCMZAwAsimQMAIDJPlN0N/0kTzLm1iYAAExGZQwAsCj7VMYkYwCARdknGdOmBgDAZFTGAACLCii66jahj9GICskYAGBRn0lKieL65EnGtKkBADAZlTEAwKLsUxmTjAEAFmWfZEybGgAAk1EZAwCsKRSMrrhNnsKYZAwAsKigoruzKXme+UEyBgBYVOCLEc31SYI1YwAATEZlDACwJirjrrndbk2YMEFZWVnKycnRzJkzdezYsXjFBgCws2AMRpIwlIx3796tqqoq7d+/Xzt37tSnn36qqVOnqqOjI17xAQDQ5xlqU+/YsSPi540bNyonJ0fNzc365je/2eU1fr9ffr8//LPP5+tFmAAA26FN3TPt7e2SpBtvvLHbc9xut5xOZ3i4XK5opgQA2AVt6qsLBoOaP3++Jk+erFGjRnV7Xm1trdrb28PD4/H0dkoAAPqkXu+mrqqq0pEjR7R3794rnudwOORwOHo7DQDAroKKrtWcRJVx
r5Lx3Llz9frrr2vPnj0aMmRIrGMCAMBWa8aGknEoFNIjjzyibdu2adeuXSosLIxXXAAA2IahZFxVVaXNmzfrt7/9rbKysuT1eiVJTqdTmZmZcQkQAGBTNno2taENXKtXr1Z7e7vuvvtuDRo0KDy2bt0ar/gAAHYViMFIEobb1AAAJISN1ox5UQQAACbjRREAAGuy0ZoxyRgAYE20qQEAQKJQGQMArCmk6FrNSbTnmGQMALAm2tQAACBRqIwBANZko8qYZAwAsCYb3dpEmxoAAJNRGQMArIk2NQAAJiMZx1+RpCyzJu8B3tQcG6fMDqAH+F3HDr/v2HjA7ACuIKG37rJmDAAAEoU2NQDAmoKKrtWcRJUxyRgAYE20qQEAQKJQGQMArInd1AAAmMxGyZg2NQAAJqMyBgBYk402cJGMAQDWRJsaAAAkCpUxAMCabFQZk4wBANYUUnTrvgl9kHZ0SMYAAGuyUWXMmjEAACajMgYAWBO3NgEAYDLa1AAAIFEMJePVq1drzJgxys7OVnZ2tkpKSvTmm2/GKzYAgJ0FYjCShKFkPGTIEC1btkzNzc06ePCg/uzP/kzf/e539V//9V/xig8AYFfBGIwkYWjNuKysLOLnJUuWaPXq1dq/f79GjhzZ5TV+v19+vz/8s8/n60WYAAD0Xb1eMw4EAtqyZYs6OjpUUlLS7Xlut1tOpzM8XC5Xb6cEANgJberuHT58WP369ZPD4dCPf/xjbdu2TSNGjOj2/NraWrW3t4eHx+OJKmAAgE0EFV0i7mWbetWqVSooKFBGRoaKi4vV1NR0xfM//PBDVVVVadCgQXI4HPr617+u7du3G5rT8K1Nt956q1paWtTe3q5XXnlFFRUV2r17d7cJ2eFwyOFwGJ0GAGB3JtxnvHXrVlVXV2vNmjUqLi5WfX29pk2bpmPHjiknJ+ey8zs7O3XPPfcoJydHr7zyigYPHqz3339fN9xwg6F5U0KhUFRP7ywtLdWwYcP0wgsv9Oh8n88np9OpE5Kyopk4zgrNDqCPOGV2AD3A7zp2+H33fSFJH0tqb29XdnZ2XOb4Mk+0r5ayM6P4nI8l58PGYi0uLtaECRO0cuVKSVIwGJTL5dIjjzyimpqay85fs2aNnnnmGR09elTXXnttr2ON+j7jYDAYsUELAICYiNGasc/nixjd5azOzk41NzertLQ0fCw1NVWlpaVqbGzs8prXXntNJSUlqqqqUm5urkaNGqWlS5cqEDC2YG0oGdfW1mrPnj06ffq0Dh8+rNraWu3atUtz5swxNCkAAFcVo1ubXC5XxEZit9vd5XTnz59XIBBQbm5uxPHc3Fx5vd4urzl58qReeeUVBQIBbd++XQsXLtSzzz6rX/ziF4b+UQ2tGbe1tam8vFwffPCBnE6nxowZo7feekv33HOPoUkBAEgUj8cT0aaO5T6mYDConJwcrV27VmlpaRo/frzOnj2rZ555RnV1dT3+HEPJeP369YYDBQCgV2L0bOovnxp5NQMGDFBaWppaW1sjjre2tiovL6/LawYNGqRrr71WaWlp4WO33XabvF6vOjs7lZ6e3qNQeTY1AMCaEnyfcXp6usaPH6+GhobwsWAwqIaGhm6fpzF58mSdOHFCweClrdvHjx/XoEGDepyIJZIxAABh1dXVWrdunTZt2qR3331XDz/8sDo6OlRZWSlJKi8vV21tbfj8hx9+WP/3f/+nefPm6fjx43rjjTe0dOlSVVVVGZqXVygCAKzJhPuMZ82apXPnzmnRokXyer0aN26cduzYEd7UdebMGaWmXqpjXS6X3nrrLS1YsEBjxozR4MGDNW/ePD322GOG5o36PmOjuM/YXrjv1F74ffd9Cb3PeJmUnRHF53wiOWviG2us0KYGAMBktKkBANZkQpvaLCRjAIA1xejWpmRAMgYAWJONkjFrxgAAmIzKGABgTawZAwBgMtrUAAAgUaiMEVfJ8ICFZHhQhZQc32UyxJgMv+9k+B4TwkaV
MckYAGBNIUW37pvQ50tGhzY1AAAmozIGAFgTbWoAAExmo1ubaFMDAGAyKmMAgDXRpgYAwGQkYwAATMaaMQAASBQqYwCANdGmBgDAZEFFl1BpUwMAgJ6iMgYAWJONNnCRjAEA1mSjNWPa1AAAmIzKGABgTbSpAQAwGW3qnlm2bJlSUlI0f/78GIUDAID99LoyPnDggF544QWNGTMmlvEAAPA5KuMru3DhgubMmaN169apf//+VzzX7/fL5/NFDAAArioYg5EkepWMq6qqNGPGDJWWll71XLfbLafTGR4ul6s3UwIA7ObLJ3D1dvTlZLxlyxYdOnRIbre7R+fX1taqvb09PDwej+EgAQDoywytGXs8Hs2bN087d+5URkZGj65xOBxyOBy9Cg4AYGMBRbfNOInWjA0l4+bmZrW1temOO+4IHwsEAtqzZ49Wrlwpv9+vtLS0mAcJALAh7jPu2re+9S0dPnw44lhlZaWGDx+uxx57jEQMAEAvGErGWVlZGjVqVMSx66+/XjfddNNlxwEAiAptagAATEabuud27doVgzAAALAvKmMAgDXRpgYAwGQ2Ssa8zxgAAJNRGQMArCmk6DZhhWIVSPyRjAEA1hSQlBLl9UmCZAwAsCYbJWPWjAEAMBmVMQDAmnjoBwAAJqNNDQAAEoXKGABgTbSpAQAwmY3a1KYl44H3S9nXmjV7D7xkdgBIlEKzA+ihjjlmR3B11yfBn5tk+X3DXqiMAQDWFFR01S1tagAAohRUdG3qJErG7KYGAMBkVMYAAGuKdgMWG7gAAIgSyRgAAJOxZgwAABKFyhgAYE20qQEAMBltagAAkChUxgAAa4q2sk2iyphkDACwpoCkUBTXJ1Eypk0NAIDJSMYAAGsKxmD0wqpVq1RQUKCMjAwVFxerqampR9dt2bJFKSkpmjlzpuE5ScYAAGsKxGAYtHXrVlVXV6uurk6HDh3S2LFjNW3aNLW1tV3xutOnT+snP/mJ7rzzTuOTimQMAOjjfD5fxPD7/d2eu2LFCj300EOqrKzUiBEjtGbNGl133XV68cUXu70mEAhozpw5euqpp3TLLbf0KkaSMQDAmmJUGbtcLjmdzvBwu91dTtfZ2anm5maVlpaGj6Wmpqq0tFSNjY3dhvnzn/9cOTk5+uEPf9jrf1RDu6mffPJJPfXUUxHHbr31Vh09erTXAQAA0KUY3drk8XiUnZ0dPuxwOLo8/fz58woEAsrNzY04npub222e27t3r9avX6+WlpaoQjV8a9PIkSP19ttvX/qAa7g7CgAQB0FFd2vTF9dmZ2dHJONY+eijj/SDH/xA69at04ABA6L6LMOZ9JprrlFeXl5UkwIAYDUDBgxQWlqaWltbI463trZ2mffee+89nT59WmVlZeFjweDn5fg111yjY8eOadiwYT2a2/Ca8e9//3vl5+frlltu0Zw5c3TmzJkrnu/3+y9bPAcA4KoSfGtTenq6xo8fr4aGhkshBINqaGhQSUnJZecPHz5chw8fVktLS3h85zvf0ZQpU9TS0iKXy9XjuQ1VxsXFxdq4caNuvfVWffDBB3rqqad055136siRI8rKyuryGrfbfdk6MwAAVxVQdC+K6EWLu7q6WhUVFSoqKtLEiRNVX1+vjo4OVVZWSpLKy8s1ePBgud1uZWRkaNSoURHX33DDDZJ02fGrMZSMp0+fHv7fY8aMUXFxsW6++Wa9/PLL3e4iq62tVXV1dfhnn89n6L8WAABIlFmzZuncuXNatGiRvF6vxo0bpx07doQ3dZ05c0apqbG/ESmq3Vc33HCDvv71r+vEiRPdnuNwOLrduQYAQLdMqIwlae7cuZo7d26X/9+uXbuueO3GjRt7NWdU6f3ChQt67733NGjQoGg+BgCAy5n0OEwzGErGP/nJT7R7926dPn1a+/bt05//+Z8rLS1Ns2fPjld8AAD0eYba1P/zP/+j2bNn63//9381cOBAfeMb39D+/fs1cODAeMUHALArk9rU
ZjCUjLds2RKvOAAAiGSjZMyzqQEAMBnPsgQAWFNISVXdRoNkDACwpF6+kjji+mRBMgYAWJKdkjFrxgAAmIzKGABgSdE+tyOJnvlBMgYAWBNtagAAkDBUxgAAS6JNDQCAyWhTAwCAhKEyBgBYUlDRVbe0qQEAiBJrxglw7mXpE7MmR8KcMjuAHig0O4Aeuv4lsyO4On7fQO9QGQMALMlOG7hIxgAASyIZAwBgMjutGXNrEwAAJqMyBgBYEm1qAABMRpsaAAAkDJUxAMCSeAIXAAAms9OaMW1qAABMRmUMALAkO23gIhkDACyJNjUAAEgYKmMAgCXZqTImGQMALIk1YwAATGanytjwmvHZs2f1/e9/XzfddJMyMzM1evRoHTx4MB6xAQBgC4Yq4z/+8Y+aPHmypkyZojfffFMDBw7U73//e/Xv3z9e8QEAbCqk6FrNoVgFkgCGkvHy5cvlcrm0YcOG8LHCwsKYBwUAAG3qbrz22msqKirSfffdp5ycHN1+++1at27dFa/x+/3y+XwRAwAAXGIoGZ88eVKrV6/W1772Nb311lt6+OGH9eijj2rTpk3dXuN2u+V0OsPD5XJFHTQAoO8LxGAki5RQKNTjtnp6erqKioq0b9++8LFHH31UBw4cUGNjY5fX+P1++f3+8M8+n08ul0snJGX1Pu64o/keG6fMDqAH+F3HDr/vvi8k6WNJ7e3tys7OjsscPp9PTqdTb0i6PorP6ZA0Q/GNNVYMVcaDBg3SiBEjIo7ddtttOnPmTLfXOBwOZWdnRwwAAHCJoQ1ckydP1rFjxyKOHT9+XDfffHNMgwIAwE4buAwl4wULFmjSpElaunSp7r//fjU1NWnt2rVau3ZtvOIDANiUnZKxoTb1hAkTtG3bNv3617/WqFGjtHjxYtXX12vOnDnxig8AgD7P8OMw7733Xt17773xiAUAgDCeTQ0AgMmCiq7VTDIGACBKdqqMDb8oAgAAxBaVMQDAkuy0m5pkDACwJDslY9rUAACYjMoYAGBJdtrARTIGAFgSbWoAAJAwVMYAAEuyU2VMMgYAWFJI0a37hmIVSALQpgYAwGRUxgAAS6JNDcRIodkB9MApswPooWT4LpMhxmT4fSfD95gI3NoEAIDJ7FQZs2YMAIDJqIwBAJZkp8qYZAwAsCQ7rRnTpgYA4CtWrVqlgoICZWRkqLi4WE1NTd2eu27dOt15553q37+/+vfvr9LS0iue3x2SMQDAkgIxGEZt3bpV1dXVqqur06FDhzR27FhNmzZNbW1tXZ6/a9cuzZ49W//+7/+uxsZGuVwuTZ06VWfPnjU0b0ooFEroQ0p8Pp+cTqdOSMpK5MQGcWuBfSTDrS4S/07GSjL8vq38uw5J+lhSe3u7srOz4zLHl3limaSMKD7nE0k1kjweT0SsDodDDoejy2uKi4s1YcIErVy5UpIUDAblcrn0yCOPqKam5qpzBgIB9e/fXytXrlR5eXmPY6UyBgD0aS6XS06nMzzcbneX53V2dqq5uVmlpaXhY6mpqSotLVVjY2OP5rp48aI+/fRT3XjjjYZiZAMXAMCSYrWBq6vKuCvnz59XIBBQbm5uxPHc3FwdPXq0R3M+9thjys/Pj0joPUEyBgBYUqxubcrOzo5bS/2rli1bpi1btmjXrl3KyDDWYCcZAwAgacCAAUpLS1Nra2vE8dbWVuXl5V3x2l/+8pdatmyZ3n77bY0ZM8bw3KwZAwAsKRiDYUR6errGjx+vhoaGSzEEg2poaFBJSUm31z399NNavHixduzYoaKiIoOzfo7KGABgSWY8gau6uloVFRUqKirSxIkTVV9fr46ODlVWVkqSysvLNXjw4PAmsOXLl2vRokXavHmzCgoK5PV6JUn9+vVTv379ejwvyRgAYElmJONZs2bp3LlzWrRokbxer8aNG6cdO3aEN3WdOXNGqamXmsqrV69WZ2en/vIv/zLic+rq6vTkk0/2eF7uM+6Gle/zQ2wlw32nEv9O
xkoy/L6t/LtO5H3Gjyv6+4yXKL6xxgqVMQDAkng2dTcKCgqUkpJy2aiqqopXfAAAmwoqukdhJlMyNlQZHzhwQIHApS78kSNHdM899+i+++6LeWAAANiFoWQ8cODAiJ+XLVumYcOG6a677oppUAAA8D7jHujs7NSvfvUrVVdXKyUlpdvz/H6//H5/+Gefz9fbKQEANsKacQ+8+uqr+vDDD/VXf/VXVzzP7XZHPKDb5XL1dkoAAPqkXifj9evXa/r06crPz7/iebW1tWpvbw8Pj8fT2ykBADZixvuMzdKrNvX777+vt99+W//yL/9y1XOv9N5IAAC6Q5v6KjZs2KCcnBzNmDEj1vEAAGA7hivjYDCoDRs2qKKiQtdcwzNDAADxwW7qK3j77bd15swZPfjgg/GIBwAASSTjK5o6daoS/DhrAIANhRTdum8yZSreZwwAgMlY9AUAWBJtagAATGanZEybGgAAk1EZAwAsyU4P/SAZAwAsiTY1AABIGCpjAIAl0aYGAMBktKkBAEDCUBkDACwpqOiqW9rUAABEiTVjAABMFlB0a6nJtGZsWjI+KOk6sybvgS1mB9BHPGB2AD1QaHYASKhk+H1b+e+fi0qOP9fJhsoYAGBJVMYAAJjMTmvG3NoEAIDJqIwBAJZEmxoAAJPRpgYAAAlDZQwAsCSewAUAgMkCklKivD5Z0KYGAMBkVMYAAEuy0wYukjEAwJLs1KYmGQMALMlOyZg1YwAATEZlDACwJNaMAQAwGW1qAACQMIaScSAQ0MKFC1VYWKjMzEwNGzZMixcvVigUild8AACbCulSq7o3I5kyk6E29fLly7V69Wpt2rRJI0eO1MGDB1VZWSmn06lHH300XjECAGwo2jZzMrWpDSXjffv26bvf/a5mzJghSSooKNCvf/1rNTU1xSU4AADswFCbetKkSWpoaNDx48clSb/73e+0d+9eTZ8+vdtr/H6/fD5fxAAA4GoCMRjJwlBlXFNTI5/Pp+HDhystLU2BQEBLlizRnDlzur3G7XbrqaeeijpQAIC9BBXdbupkurXJUGX88ssv66WXXtLmzZt16NAhbdq0Sb/85S+1adOmbq+pra1Ve3t7eHg8nqiDBgCgLzFUGf/0pz9VTU2NHnjgAUnS6NGj9f7778vtdquioqLLaxwOhxwOR/SRAgBshQ1c3bh48aJSUyOL6bS0NAWDydQMAAAkA5JxN8rKyrRkyRINHTpUI0eO1DvvvKMVK1bowQcfjFd8AACbstOasaFk/Nxzz2nhwoX627/9W7W1tSk/P19/8zd/o0WLFsUrPgAA+jxDyTgrK0v19fWqr6+PUzgAAHwu2sq2z1bGAAAkip2SMS+KAADAZFTGAABLCii6lz0kU2VMMgYAWJKdkjFtagAATEZlDACwJDtt4CIZAwAsiTY1AABIGCpjAIAlBRVdZRzNtYlGMgYAWFK0z6ZOpmRMmxoAYEmBGIzeWLVqlQoKCpSRkaHi4mI1NTVd8fzf/OY3Gj58uDIyMjR69Ght377d8JwkYwAAvrB161ZVV1errq5Ohw4d0tixYzVt2jS1tbV1ef6+ffs0e/Zs/fCHP9Q777yjmTNnaubMmTpy5IiheVNCoVBCK/n29nbdcMMNelHSdYmcGKaoNDsAIAltMDuAK7go6UFJH374oZxOZ1zm8Pl8cjqdylT0beqPJXk8HmVnZ4ePOxwOORyOLq8pLi7WhAkTtHLlSklSMBiUy+XSI488opqamsvOnzVrljo6OvT666+Hj/3pn/6pxo0bpzVr1hgINsE8Hk/oi++IwWAwGEk6PB5P3PLExx9/HMrLy4tJnP369bvsWF1dXZfz+v3+UFpaWmjbtm0Rx8vLy0Pf+c53urzG5XKF/v7v/z7i2KJFi0Jjxowx9M+c8A1c+fn58ng8ysrKUkpKNP/N8zmfzyeXy3XZf/nAGL7H2OB7jB2+y9iI9fcYCoX00UcfKT8/PwbRdS0jI0OnTp1SZ2dn1J8V
CoUuyzXdVcXnz59XIBBQbm5uxPHc3FwdPXq0y2u8Xm+X53u9XkNxJjwZp6amasiQITH/3OzsbP7AxgDfY2zwPcYO32VsxPJ7jFd7+qsyMjKUkZER93msgg1cAABIGjBggNLS0tTa2hpxvLW1VXl5eV1ek5eXZ+j87pCMAQCQlJ6ervHjx6uhoSF8LBgMqqGhQSUlJV1eU1JSEnG+JO3cubPb87uT9A/9cDgcqqur63YNAD3D9xgbfI+xw3cZG3yPxlRXV6uiokJFRUWaOHGi6uvr1dHRocrKz+8NKS8v1+DBg+V2uyVJ8+bN01133aVnn31WM2bM0JYtW3Tw4EGtXbvW0LwJv7UJAAArW7lypZ555hl5vV6NGzdO//iP/6ji4mJJ0t13362CggJt3LgxfP5vfvMbPfHEEzp9+rS+9rWv6emnn9a3v/1tQ3OSjAEAMBlrxgAAmIxkDACAyUjGAACYjGQMAIDJkj4ZG33VFSK53W5NmDBBWVlZysnJ0cyZM3Xs2DGzw0p6y5YtU0pKiubPn292KEnn7Nmz+v73v6+bbrpJmZmZGj16tA4ePGh2WEklEAho4cKFKiwsVGZmpoYNG6bFixeL/brWldTJ2OirrnC53bt3q6qqSvv379fOnTv16aefaurUqero6DA7tKR14MABvfDCCxozZozZoSSdP/7xj5o8ebKuvfZavfnmm/rv//5vPfvss+rfv7/ZoSWV5cuXa/Xq1Vq5cqXeffddLV++XE8//bSee+45s0NDN5L61iajr7rC1Z07d045OTnavXu3vvnNb5odTtK5cOGC7rjjDj3//PP6xS9+oXHjxqm+vt7ssJJGTU2N/vM//1P/8R//YXYoSe3ee+9Vbm6u1q9fHz72F3/xF8rMzNSvfvUrEyNDd5K2Mu7s7FRzc7NKS0vDx1JTU1VaWqrGxkYTI0tu7e3tkqQbb7zR5EiSU1VVlWbMmBHx7yV67rXXXlNRUZHuu+8+5eTk6Pbbb9e6devMDivpTJo0SQ0NDTp+/Lgk6Xe/+5327t2r6dOnmxwZupO0j8PszauucGXBYFDz58/X5MmTNWrUKLPDSTpbtmzRoUOHdODAAbNDSVonT57U6tWrVV1drZ/97Gc6cOCAHn30UaWnp6uiosLs8JJGTU2NfD6fhg8frrS0NAUCAS1ZskRz5swxOzR0I2mTMWKvqqpKR44c0d69e80OJel4PB7NmzdPO3futNVr32ItGAyqqKhIS5culSTdfvvtOnLkiNasWUMyNuDll1/WSy+9pM2bN2vkyJFqaWnR/PnzlZ+fz/doUUmbjHvzqit0b+7cuXr99de1Z8+euLxvuq9rbm5WW1ub7rjjjvCxQCCgPXv2aOXKlfL7/UpLSzMxwuQwaNAgjRgxIuLYbbfdpn/+5382KaLk9NOf/lQ1NTV64IEHJEmjR4/W+++/L7fbTTK2qKRdM+7Nq65wuVAopLlz52rbtm36t3/7NxUWFpodUlL61re+pcOHD6ulpSU8ioqKNGfOHLW0tJCIe2jy5MmX3Vp3/Phx3XzzzSZFlJwuXryo1NTIv97T0tIUDAZNighXk7SVsXT1V13h6qqqqrR582b99re/VVZWlrxeryTJ6XQqMzPT5OiSR1ZW1mXr7Ndff71uuukm1t8NWLBggSZNmqSlS5fq/vvvV1NTk9auXWv4dXR2V1ZWpiVLlmjo0KEaOXKk3nnnHa1YsUIPPvig2aGhO6Ek99xzz4WGDh0aSk9PD02cODG0f/9+s0NKKpK6HBs2bDA7tKR31113hebNm2d2GEnnX//1X0OjRo0KORyO0PDhw0Nr1641O6Sk4/P5QvPmzQsNHTo0lJGREbrllltCjz/+eMjv95sdGrqR1PcZAwDQFyTtmjEAAH0FyRgAAJORjAEAMBnJGAAAk5GMAQAwGckYAACTkYwBADAZyRgAAJORjAEAMBnJGAAAk5GMAQAw2f8DdFXPCGw7fqYAAAAASUVORK5CYII=
\n",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(A[0,:,:].cpu().data, cmap='hot', interpolation='nearest')\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2500],\n",
       "        [0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2500],\n",
       "        [0.3333, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000],\n",
       "        [0.3333, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333, 0.0000],\n",
       "        [0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.3333, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000],\n",
       "        [0.3333, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333, 0.0000],\n",
       "        [0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2500]],\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A[0]"
   ]
  },
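  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The matrix above suggests how the toy GPT counts. Each token attends uniformly to `<BOS>` and to every position holding the same token, so the weight on `<BOS>` is $1/(c+1)$, where $c$ is the number of occurrences of that token. This is the selector-width trick from the RASP paper. The next cell is a minimal sketch that rebuilds this idealized pattern for a sequence with the same structure as the one printed above, and reads the counts back as $1/A_{i,0} - 1$. `ideal_hist_attention`, `seq_demo` and its token values are illustrative; we do not claim this is exactly how `fake_mlp` performs the readout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def ideal_hist_attention(seq):\n",
    "    # each token attends uniformly to <BOS> (position 0) and to all positions\n",
    "    # holding the same token; <BOS> itself only matches position 0\n",
    "    same = (seq.unsqueeze(0) == seq.unsqueeze(1)).float()  # (T, T)\n",
    "    same[:, 0] = 1.0\n",
    "    return same / same.sum(dim=-1, keepdim=True)\n",
    "\n",
    "# same structure as the sequence behind A[0] above (hypothetical token values)\n",
    "seq_demo = torch.tensor([0, 1, 1, 2, 3, 4, 2, 3, 1])\n",
    "A_ideal = ideal_hist_attention(seq_demo)\n",
    "# the weight on <BOS> is 1/(count+1), hence count = 1/A[i,0] - 1\n",
    "counts_demo = (1 / A_ideal[:, 0] - 1).round().long()\n",
    "print(counts_demo.tolist())  # [0, 3, 3, 2, 2, 1, 2, 2, 3]"
   ]
  },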
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIF0lEQVR4nO3csXHcyBZA0Zmt7/2U6DMB+sxAvtJQHEyAPlOijfWujddVwELgOTYEdGOac2sMvee2bdsDAB6Pxz//9QIAuA5RACCiAEBEAYCIAgARBQAiCgBEFADI/4668f+fz6NuDcCC7x3/V9kvBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAyHPbdgzDeDwe7yfMMvo4/An8rd4W/s30PK08Y+roM37Ge7qDO3zWK8w+AmBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAILsH4r0OB+J9LSzmZXj99BnT+1/VyrvlGo4+4yuuuKapK+7hjDVNn/FpIB4AE6IAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYDsnn30Ppx99LG0nJm34fVnrOkOpu91hfOxzx32sOIn7nvl726672+zjwCYEAUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAGT37KPX4eyjr4XFvAyvnz5jev+rWnm3XMPRZ3zFFdc0dcU9nLGm6TM+zT4CYEIUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAOS5bdu258Lfz+foxn+WlsMevw6+/xmf3XQPztN1HH3+Hg+f91G+d3zd+6UAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAkOe2bdueC9+fz9GNP5aWM/M2vP6MNd3B9L2ucD72ucMeVvzEfa/83U33/b3j694vBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgB5btu27bnw9/M5uvGfpeWwx6+D73/GZzfdg/N0HUefv8fD532U7x1f934pABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBAds8+eh/OPlrxcfgT+Fu9Lfyb6XlaecbU0Wf8jPd0B3f4rFeYfQTAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEAOG4h3xjCo6VCrKw6ouqK7DAu7w/m4wx5W/MR9nzHM0EA8AEZEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAOWz20Yo7zC/hGGfMhbnDzKcz3tMd3OGzXmH2EQAjogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKACQ57Zt254LX5/P0Y2/FhbzMrx++ozp/a9q5d1yDUef8RVXXNPUFfdwxpqmz/jc8XXvlwIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAGT37KP34eyjj6XlzLwNrz9jTXcwfa8rnI997rCHFT9x3yt/d9N9f5t9BMCEKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIIfNPlpxh/klHOOMuTB3mPl0xnu6gzt81ivMPgJgRBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACC7B+K9DgfifS0s5mV4/fQZ
0/tf1cq75RqOPuMrrrimqSvu4Yw1TZ/xaSAeABOiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGA7J599Hs4++jP0nLY49fB9z/js5vuwXm6jqPP3+Ph8z7Kt9lHAEyIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACCiAECe27Ztey58fz5HN/5YWs7M2/D6M9Z0B9P3usL52OcOe1jxE/e98nc33ff3jq97vxQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAIgoABBRACC7Zx+9DmcffS0s5mV4/fQZ0/tf1cq75RqOPuMrrrimqSvu4Yw1TZ/xafYRABOiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGA7J599D6cfbTi4/An8Ld6W/g30/O08oypo8/4Ge/pDu7wWa/4NvsIgAlRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgIgCABEFACIKAEQUAMhz27Ztz4Wvz+foxl8Li3kZXj99xvT+V7XybrmGo8/4iiuuaeqKezhjTdNnfO74uvdLAYCIAgARBQAiCgBEFACIKAAQUQAgogBARAGAiAIAEQUAsnv20ftw9tHH0nJm3obXn7GmO5i+1xXOxz532MOKn7jvlb+76b6/zT4CYEIUAIgoABBRACCiAEBEAYCIAgARBQAiCgBEFACIKAAQUQAguwfivQ4H4n0tLOZleP30GdP7X9XKu+Uajj7jK664pqkr7uGMNU2f8WkgHgATogBARAGAiAIAEQUAIgoARBQAiCgAEFEAIKIAQEQBgOyefQTA/fmlAEBEAYCIAgARBQAiCgBEFACIKAAQUQAgogBA/gUU3CCpB4vbqgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "one_sample = torch.randint(1, 5, (1,30))\n",
    "one_sample[0,0] = 0\n",
    "y, A = gh(one_sample)\n",
    "plt.imshow(A[0].cpu().data, cmap='hot', interpolation='nearest')\n",
    "#plt.colorbar()\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating your dataset\n",
    "\n",
    "Now, we will use a 'micro' GPT to learn the histogram task. Before that, we use our 'toy' GPT to generate the dataset. Since our GPT uses no positional information, it is permutation equivariant (permuting the input permutes the output in the same way), so we can restrict the inputs to sorted sequences. The number of distinct sorted inputs is small enough to enumerate them all: for `seq_train=s` digits (after the `<BOS>` token) taking values among `nb_digits=n` possible digits, there are ${s+n-1 \\choose n-1}$ possibilities (a stars-and-bars count). We then pass each such sequence through our toy GPT to get its label."
   ]
  },
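  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The equivariance claim can be checked numerically. The next cell is a minimal sketch with a single-head self-attention layer without positional encoding; `attention_no_pos` and its random projections are illustrative stand-ins, not the notebook's `SelfAttentionLayer`. Permuting the rows of the input permutes the rows of the output in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "def attention_no_pos(x, q_proj, k_proj, v_proj):\n",
    "    # single-head self-attention without positional encoding; x: (T, C)\n",
    "    scores = q_proj(x) @ k_proj(x).T / x.shape[-1] ** 0.5  # (T, T)\n",
    "    return torch.softmax(scores, dim=-1) @ v_proj(x)       # (T, C)\n",
    "\n",
    "n_ch = 8\n",
    "q_proj, k_proj, v_proj = (nn.Linear(n_ch, n_ch) for _ in range(3))\n",
    "x_demo = torch.randn(6, n_ch)\n",
    "perm = torch.randperm(6)\n",
    "with torch.no_grad():\n",
    "    out = attention_no_pos(x_demo, q_proj, k_proj, v_proj)\n",
    "    out_perm = attention_no_pos(x_demo[perm], q_proj, k_proj, v_proj)\n",
    "# permuting the input rows permutes the output rows in the same way\n",
    "print(torch.allclose(out[perm], out_perm, atol=1e-6))  # True"
   ]
  },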
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5456"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq_train = 30\n",
    "nb_digits = 4\n",
    "comb = combinations_with_replacement(range(0,seq_train+1), nb_digits-1)\n",
    "\n",
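    "def make_seq(c, seq_train):\n",
    "    # c: sorted cut points in [0, seq_train]; the gaps between consecutive\n",
    "    # cut points give the number of occurrences of each digit\n",
    "    c_l = [0] + list(c) + [seq_train]\n",
    "    len_seq = len(c_l)-1\n",
    "    return [c_l[i+1]-c_l[i] for i in range(len_seq)]\n",
    "\n",
    "l_comb = [make_seq(c,seq_train) for c in comb]\n",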
    "\n",
    "len(l_comb)"
   ]
  },
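  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this enumeration yields exactly $\\binom{s+n-1}{n-1}$ sequences, here is a tiny worked example with $s=3$ digits drawn from $n=2$ values: every sorted choice of $n-1$ cut points in $\\{0,\\dots,s\\}$ splits the $s$ positions into $n$ counts. `counts_from_cuts` is a hypothetical copy of `make_seq` so that the cell runs on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from itertools import combinations_with_replacement\n",
    "\n",
    "def counts_from_cuts(cuts, s):\n",
    "    # same logic as make_seq: the gaps between sorted cut points are the digit counts\n",
    "    c_l = [0] + list(cuts) + [s]\n",
    "    return [c_l[i + 1] - c_l[i] for i in range(len(c_l) - 1)]\n",
    "\n",
    "s_demo, n_demo = 3, 2  # length-3 sequences over 2 digits\n",
    "all_counts = [counts_from_cuts(c, s_demo)\n",
    "              for c in combinations_with_replacement(range(s_demo + 1), n_demo - 1)]\n",
    "print(all_counts)  # [[0, 3], [1, 2], [2, 1], [3, 0]]\n",
    "print(len(all_counts) == math.comb(s_demo + n_demo - 1, n_demo - 1))  # True"
   ]
  },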
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 30])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "one_sample.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5456"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "math.comb(seq_train+nb_digits-1, nb_digits-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
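    "def make_inputs(l_comb, nb_digits=nb_digits):\n",
    "    # turn each count vector t into a sorted sequence: <BOS>=0 followed by\n",
    "    # digit i+1 repeated t[i] times (the nb_digits argument is unused)\n",
    "    inputs = []\n",
    "    for t in l_comb:\n",
    "        curr = [0]\n",
    "        for (i,j) in enumerate(t):\n",
    "            curr += [i+1 for _ in range(j)]\n",
    "        inputs.append(torch.tensor(np.array(curr)))\n",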
    "    return inputs\n",
    "\n",
    "def make_loader(len_seq,nb_digits):\n",
    "    comb = combinations_with_replacement(range(0,len_seq+1), nb_digits-1)\n",
    "    l_comb =  [make_seq(c,len_seq) for c in comb]\n",
    "    inputs = make_inputs(l_comb)\n",
    "    # labels: per-token outputs of the toy GPT gh, cast to integer class indices\n",
    "    labels = [(gh(d.unsqueeze(0))[0].squeeze(0).squeeze(1)).type(torch.LongTensor) for d in inputs]\n",
    "    dataset = list(zip(inputs,labels))\n",
    "    len_in = len(dataset)\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)\n",
    "    return loader, len_in, inputs"
   ]
  },
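  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the labels, the histogram can also be computed directly, without the toy GPT. The next cell is a sketch with two hypothetical helpers: `expand_counts`, which mirrors `make_inputs` for a single count vector, and `histogram_labels`, which compares every token with every other token and zeroes the `<BOS>` position. It assumes, as in this notebook, that `<BOS>` is token 0 and appears only at position 0. Applying `histogram_labels` to the entries of `inputs_train` and comparing with the labels produced by `gh` would check the whole dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def expand_counts(counts):\n",
    "    # <BOS>=0 followed by digit i+1 repeated counts[i] times, as in make_inputs\n",
    "    seq = [0]\n",
    "    for i, c in enumerate(counts):\n",
    "        seq += [i + 1] * c\n",
    "    return torch.tensor(seq)\n",
    "\n",
    "def histogram_labels(seq):\n",
    "    # number of occurrences of each token (pairwise comparison), 0 at <BOS>\n",
    "    labels = (seq.unsqueeze(0) == seq.unsqueeze(1)).sum(dim=1)\n",
    "    labels[seq == 0] = 0\n",
    "    return labels\n",
    "\n",
    "x_counts = expand_counts([1, 2, 0, 3])\n",
    "print(x_counts.tolist())                    # [0, 1, 2, 2, 4, 4, 4]\n",
    "print(histogram_labels(x_counts).tolist())  # [0, 1, 2, 2, 3, 3, 3]"
   ]
  },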
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, size_train, inputs_train = make_loader(seq_train,nb_digits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5456"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "size_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_in = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 31])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_in[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 31])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_in[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,\n",
       "        3, 3, 3, 3, 3, 3, 3])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_in[0][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0,  2,  2, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,\n",
       "        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_in[1][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Coding 'micro' GPT\n",
    "\n",
    "Now we code the 'micro' GPT used for learning. The goal is to reuse our `SelfAttentionLayer` above without any modification: the only change is that the hard-coded `fake_mlp` is replaced by a real, trainable MLP (two linear layers with a ReLU in between)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.attn = SelfAttentionLayer(config)\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(config.n_channels, 4 * config.n_channels),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(4 * config.n_channels, config.n_channels),\n",
    "        )\n",
    "\n",
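    "    def forward(self, x, verbose=False):  # x: (bs, T, n_channels)\n",
    "        y, A = self.attn(x)  # A: attention weights, (bs, T, T)\n",
    "        x = x + y  # residual connection around attention\n",
    "        x = x + self.mlp(x)  # residual MLP, (bs, T, n_channels)\n",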
    "        if verbose:\n",
    "            return x, A\n",
    "        else:\n",
    "            return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.n_channels = config.n_channels\n",
    "        self.nb_digits = config.nb_digits\n",
    "        self.tok_emb = nn.Embedding(self.nb_digits+1,self.n_channels)\n",
    "        self.block = Block(config)\n",
    "        self.head = nn.Linear(config.n_channels, 1+config.max_hist)\n",
    "        \n",
    "    def forward(self, idx, targets=None, verbose=False):\n",
    "        # shape of idx: (bs, len) 0=bos and 1...nb_digits\n",
    "        # shape of targets: (bs, len)\n",
    "        x = self.tok_emb(idx)  # (bs, len, n_channels)\n",
    "        if verbose:\n",
    "            x, A = self.block(x,verbose=verbose)\n",
    "        else:\n",
    "            x = self.block(x) # x: (bs, len, n_channels)\n",
    "        logits = self.head(x) # logits: (bs, len, 1+max_hist)\n",
    "        \n",
    "        loss = None\n",
    "        if targets is not None:\n",
    "            # flatten to (bs*len, 1+max_hist) and (bs*len,) for cross-entropy\n",
    "            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))\n",
    "        if verbose:\n",
    "            return logits, loss, A\n",
    "        else:\n",
    "            return logits, loss"
   ]
  },
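  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`GPT.forward` flattens the logits to shape `(bs*len, 1+max_hist)` and the targets to `(bs*len,)` before calling `F.cross_entropy`. The next cell, on random tensors with illustrative sizes, checks that this gives the same mean loss as passing the class dimension second, i.e. an input of shape `(bs, C, len)`, which `F.cross_entropy` also accepts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "bs_demo, len_demo, n_classes = 4, 7, 5\n",
    "logits_demo = torch.randn(bs_demo, len_demo, n_classes)\n",
    "targets_demo = torch.randint(0, n_classes, (bs_demo, len_demo))\n",
    "\n",
    "# flattened form, as in GPT.forward\n",
    "loss_flat = F.cross_entropy(logits_demo.view(-1, n_classes), targets_demo.view(-1))\n",
    "# class dimension second: input (bs, C, len), target (bs, len)\n",
    "loss_perm = F.cross_entropy(logits_demo.transpose(1, 2), targets_demo)\n",
    "print(torch.allclose(loss_flat, loss_perm))  # True"
   ]
  },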
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "class config_gpt:\n",
    "    nb_digits = nb_digits\n",
    "    n_channels = 32 \n",
    "    key_channels = 64 \n",
    "    max_hist = seq_train+1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "gptmini = GPT(config_gpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits, _ = gptmini(batch_in[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 31, 32])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "_,preds = torch.max(logits,-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 31])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 31])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_in[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(81)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.sum(preds == batch_in[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, dataloader, size, epochs=1, optimizer=None):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        running_corrects = 0\n",
    "        n_batch = 0\n",
    "        for inputs,targets in dataloader:\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            logits, loss = model(inputs,targets)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # token-level accuracy of the current batch\n",
    "            _, preds = torch.max(logits, -1)\n",
    "            running_corrects += torch.true_divide(torch.sum(preds == targets.data), targets.shape[0]*targets.shape[1])\n",
    "            running_loss += loss.data.item()\n",
    "            n_batch += 1\n",
    "        # average over batches (each batch weighs the same, whatever its size)\n",
    "        epoch_loss = running_loss / n_batch\n",
    "        epoch_acc = running_corrects.data.item() / n_batch\n",
    "        print('Loss: {:.4f} Acc: {:.4f}'.format(\n",
    "                     epoch_loss, epoch_acc))"
   ]
  },
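  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`train_model` averages the per-batch accuracies (and losses), so every batch counts equally regardless of its size. With 5456 training sequences and `batch_size=128`, the last batch holds only 80 sequences, so the reported accuracy is close to, but not exactly, the token-level accuracy. The next cell is a minimal sketch of the exact version, weighting each batch by its number of tokens; `token_accuracy` is a hypothetical helper, not used elsewhere in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def token_accuracy(batches):\n",
    "    # exact token-level accuracy: each batch is weighted by its number of tokens\n",
    "    correct, total = 0, 0\n",
    "    for preds, targets in batches:\n",
    "        correct += (preds == targets).sum().item()\n",
    "        total += targets.numel()\n",
    "    return correct / total\n",
    "\n",
    "# two toy batches: 2 tokens, all correct; 8 tokens, half correct\n",
    "b1 = (torch.tensor([[1, 2]]), torch.tensor([[1, 2]]))\n",
    "b2 = (torch.tensor([[1, 1, 1, 1], [2, 2, 2, 2]]), torch.tensor([[1, 1, 1, 1], [0, 0, 0, 0]]))\n",
    "print(token_accuracy([b1, b2]))  # 0.6 (6 correct out of 10)\n",
    "# averaging the per-batch accuracies would give (1.0 + 0.5) / 2 = 0.75"
   ]
  },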
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "gptmini = GPT(config_gpt)\n",
    "gptmini = gptmini.to(device)\n",
    "lr = 0.01\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 2.8989 Acc: 0.1404\n",
      "Loss: 2.3290 Acc: 0.2366\n",
      "Loss: 1.8414 Acc: 0.3420\n",
      "Loss: 1.5484 Acc: 0.4130\n",
      "Loss: 1.3893 Acc: 0.4707\n",
      "Loss: 1.1710 Acc: 0.5449\n",
      "Loss: 1.2235 Acc: 0.5548\n",
      "Loss: 1.1447 Acc: 0.5624\n",
      "Loss: 0.9973 Acc: 0.5997\n",
      "Loss: 1.0354 Acc: 0.5942\n",
      "Loss: 0.7673 Acc: 0.7260\n",
      "Loss: 2.0153 Acc: 0.5679\n",
      "Loss: 1.3302 Acc: 0.5225\n",
      "Loss: 1.0100 Acc: 0.6331\n",
      "Loss: 0.9950 Acc: 0.6192\n"
     ]
    }
   ],
   "source": [
    "len_train = (seq_train+1)*size_train\n",
    "train_model(gptmini,train_loader,size_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 1.6273 Acc: 0.5895\n",
      "Loss: 0.8585 Acc: 0.7432\n",
      "Loss: 0.7456 Acc: 0.7976\n",
      "Loss: 0.6564 Acc: 0.8209\n",
      "Loss: 0.6194 Acc: 0.8226\n",
      "Loss: 0.5965 Acc: 0.8187\n",
      "Loss: 0.5400 Acc: 0.8298\n",
      "Loss: 0.4599 Acc: 0.8847\n",
      "Loss: 0.5487 Acc: 0.7952\n",
      "Loss: 0.4360 Acc: 0.8818\n",
      "Loss: 0.4119 Acc: 0.8772\n",
      "Loss: 0.5157 Acc: 0.7874\n",
      "Loss: 0.4393 Acc: 0.8326\n",
      "Loss: 0.3695 Acc: 0.8922\n",
      "Loss: 0.3329 Acc: 0.9038\n"
     ]
    }
   ],
   "source": [
    "lr = 0.005\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)\n",
    "train_model(gptmini,train_loader,len_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 0.3562 Acc: 0.9074\n",
      "Loss: 0.2381 Acc: 0.9965\n",
      "Loss: 0.2271 Acc: 0.9964\n",
      "Loss: 0.2194 Acc: 0.9964\n",
      "Loss: 0.2115 Acc: 0.9953\n",
      "Loss: 0.1993 Acc: 0.9976\n",
      "Loss: 0.1878 Acc: 0.9975\n",
      "Loss: 0.1770 Acc: 0.9978\n",
      "Loss: 0.1705 Acc: 0.9973\n",
      "Loss: 0.1589 Acc: 0.9982\n",
      "Loss: 0.1484 Acc: 0.9986\n",
      "Loss: 0.1395 Acc: 0.9978\n",
      "Loss: 0.1428 Acc: 0.9915\n",
      "Loss: 0.1259 Acc: 0.9978\n",
      "Loss: 0.1150 Acc: 0.9989\n"
     ]
    }
   ],
   "source": [
    "lr = 0.001\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)\n",
    "train_model(gptmini,train_loader,len_train,15,optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 0.1069 Acc: 0.9995\n",
      "Loss: 0.1024 Acc: 0.9998\n",
      "Loss: 0.1009 Acc: 1.0000\n",
      "Loss: 0.0997 Acc: 0.9999\n",
      "Loss: 0.0980 Acc: 0.9999\n",
      "Loss: 0.0966 Acc: 1.0000\n",
      "Loss: 0.0950 Acc: 1.0000\n",
      "Loss: 0.0933 Acc: 1.0000\n",
      "Loss: 0.0919 Acc: 1.0000\n",
      "Loss: 0.0906 Acc: 0.9998\n",
      "Loss: 0.0889 Acc: 1.0000\n",
      "Loss: 0.0872 Acc: 1.0000\n",
      "Loss: 0.0857 Acc: 1.0000\n",
      "Loss: 0.0841 Acc: 1.0000\n",
      "Loss: 0.0826 Acc: 1.0000\n"
     ]
    }
   ],
   "source": [
    "lr = 0.0001\n",
    "optimizer = torch.optim.Adam(gptmini.parameters(),lr = lr)\n",
    "train_model(gptmini,train_loader,len_train,15,optimizer)"
   ]
  },
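  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells above lower the learning rate by building a new `torch.optim.Adam` each time, which also discards Adam's running moment estimates. If one prefers to keep them, the learning rate can be changed in place through `optimizer.param_groups`, as in the minimal sketch below on a stand-in model. `model_demo` and `set_lr` are illustrative names, and we have not tested whether keeping the moments changes training here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "model_demo = nn.Linear(3, 2)  # stand-in for gptmini\n",
    "opt_demo = torch.optim.Adam(model_demo.parameters(), lr=0.01)\n",
    "\n",
    "def set_lr(optimizer, lr):\n",
    "    # update the learning rate of every parameter group in place\n",
    "    for group in optimizer.param_groups:\n",
    "        group['lr'] = lr\n",
    "\n",
    "# one optimization step so that Adam creates its moment estimates\n",
    "loss_demo = model_demo(torch.randn(5, 3)).pow(2).mean()\n",
    "loss_demo.backward()\n",
    "opt_demo.step()\n",
    "\n",
    "set_lr(opt_demo, 0.001)\n",
    "# the new learning rate is used, and the Adam state of weight and bias is kept\n",
    "print(opt_demo.param_groups[0]['lr'], len(opt_demo.state))  # 0.001 2"
   ]
  },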
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128, 31, 31])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "one_batch = batch_in[0].to(device)\n",
    "logits, loss, A = gptmini(one_batch,verbose=True)\n",
    "A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfQAAAGiCAYAAAARATRgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA4wUlEQVR4nO3dfVBU993//9dCAnjHKkHuFAU01VgFWlTK1WhNpYLX9U01MTNoM1WpYyZGM0lo7mwT0JoZ1OSyNi3VGVujpjXaXL/ETDMtaboNpF5FvcQw5k5GHS0YXby5BlaxLoY9vz9ycchG0F12YTnL8zHzmZGzn3PO+3Q7ee/78/mcc2yGYRgCAACWFhHqAAAAQOBI6AAAhAESOgAAYYCEDgBAGCChAwAQBkjoAACEARI6AABhgIQOAEAYIKEDABAGSOgAAIQBEjoAAH6qqKhQWlqaYmJilJubq0OHDvm03549e2Sz2TR//nyv7UuXLpXNZvNqhYWFfsVEQgcAwA979+5VSUmJysrKdOTIEWVlZamgoEDnz5+/6X6nT5/Wk08+qRkzZnT5eWFhoc6dO2e21157za+4SOgAAPhh06ZNWr58uYqLizVp0iRt3bpVgwcP1vbt27vdp729XQ8++KDWrl2rjIyMLvtER0crKSnJbCNGjPArrtv86t0HPB6Pzp49q2HDhslms4U6HACAnwzD0OXLl5WSkqKIiN6rG69du6a2traAj2MYxg35Jjo6WtHR0Tf0bWtrU21trVavXm1ui4iIUH5+vmpqaro9x89+9jMlJCRo2bJl+vvf/95ln6qqKiUkJGjEiBH67ne/qxdeeEF33HGHz9fR7xL62bNnlZqaGuowAAABamxs1OjRo3vl2NeuXVN6erqcTmfAxxo6dKiuXLnita2srExr1qy5oe/FixfV3t6uxMREr+2JiYk6duxYl8ffv3+/fvvb36qurq7bGAoLC3X//fcrPT1dJ0+e1E9+8hPNnTtXNTU1ioyM9Ok6+l1CHzZsmCRplnwP7vPeCgboY02hDgDoRpwffT+XdFCd/z3vDW1tbXI6nWpsPKXY2NgeH8flcik1NV2NjY1ex+mqOu+Jy5cv64c//KG2bdum+Pj4bvstXLjQ/PeUKVOUmZmpcePGqaqqSrNnz/bpXL2W0CsqKvTiiy/K6XQqKytLv/zlLzV9+vRb7tcx7HGbpNt7Kzign/LtdzjQ93qSLPpi2jQ2NjaghO7vceLj4xUZGammJu+f301NTUpKSrqh/8mTJ3X69Gnde++95jaPxyNJuu2221RfX69x48bdsF9GRobi4+N14sQJnxN6r0xu9HQFIAAA/vk8CM13UVFRysnJkcPhMLd5PB45HA7l5eXd0H/ixIn68MMPVVdXZ7bvf//7uueee1RXV9ftFPOZM2d06dIlJScn+xxbryT0nqwABADAf32b0CWppKRE27Zt086dO/Xpp59qxYoVam1tVXFxsSRp8eLF5qK5mJgYTZ482asNHz5cw4YN0+TJkxUVFaUrV67oqaee0oEDB3T69Gk5HA7NmzdP48ePV0FBgc9xBX3I3d8VgG63W2632/zb5XIFOyQAQNjqWVL23t8/RUVFunDhgkpLS+V0OpWdna3KykpzoVxDQ4Nfq/sjIyN19OhR7dy5U83NzUpJSdGcOXO0bt06v+byg57Q/V0BWF5errVr1wY7DAAAes2qVau0atWqLj+rqqq66b47duzw+nvQoEF65513Ao4p5A+WWb16tVpaWszW2NgY6pAAAJbRrsCG29v7PuReEvQK3d8VgN3dvA8AwK31/ZB7fxX0Ct3fFYAAACBwvXIfeklJiZYsWaKpU6dq+vTp2rx5s9cKQAAAgoMKvUOvJPRbrQD0xf/X0uLzwwKG8Mx3hIlpoQ4A6IY/w7l9uziLhN6h154Ud7MVgAAAILj63bPcAQDwXbsCW6nOKncAAPqBjtvWAtk/PIT8PnQAABA4
KnQAgIWxKK4DCR0AYGEk9A4kdACAhZHQOzCHDgBAGKBCBwBYGKvcO5DQAQAWxpB7B4bcAQAIA1ToAAALo0LvQEIHAFgYCb0DQ+4AAIQBKnQAgIVRoXcgoQMALIzb1jow5A4AQBigQgcAWBhD7h1I6AAACyOhdyChAwAsjITegTl0AADCABU6AMDCqNA7kNABABbGbWsdGHIHACAMUKEDACysXYFV2VToAAD0A58HofmvoqJCaWlpiomJUW5urg4dOuTTfnv27JHNZtP8+fO9thuGodLSUiUnJ2vQoEHKz8/X8ePH/YqJhA4AgB/27t2rkpISlZWV6ciRI8rKylJBQYHOnz9/0/1Onz6tJ598UjNmzLjhs40bN+rll1/W1q1bdfDgQQ0ZMkQFBQW6du2az3GR0AEAFtb3FfqmTZu0fPlyFRcXa9KkSdq6dasGDx6s7du3d7tPe3u7HnzwQa1du1YZGRlenxmGoc2bN+u5557TvHnzlJmZqV27duns2bPat2+fz3GR0AEAFtaxyr2n7Ys5dJfL5dXcbneXZ2tra1Ntba3y8/PNbREREcrPz1dNTU23Uf7sZz9TQkKCli1bdsNnp06dktPp9Dqm3W5Xbm7uTY/5VSR0AMCAl5qaKrvdbrby8vIu+128eFHt7e1KTEz02p6YmCin09nlPvv379dvf/tbbdu2rcvPO/bz55hdYZU7AMDCgvNgmcbGRsXGxppbo6OjAwvr/1y+fFk//OEPtW3bNsXHxwflmN0hoQMALCw4CT02NtYroXcnPj5ekZGRampq8tre1NSkpKSkG/qfPHlSp0+f1r333mtu83g8kqTbbrtN9fX15n5NTU1KTk72OmZ2drbPV8KQOwDAwvp2UVxUVJRycnLkcDjMbR6PRw6HQ3l5eTf0nzhxoj788EPV1dWZ7fvf/77uuece1dXVKTU1Venp6UpKSvI6psvl0sGDB7s8Zneo0AEA8ENJSYmWLFmiqVOnavr06dq8ebNaW1tVXFwsSVq8eLFGjRql8vJyxcTEaPLkyV77Dx8+XJK8tj/++ON64YUXdOeddyo9PV3PP/+8UlJSbrhf/WZI6AAAC+v7l7MUFRXpwoULKi0tldPpVHZ2tiorK81FbQ0NDYqI8G8A/Omnn1Zra6seeughNTc36+6771ZlZaViYmJ8PobNMAzDr7P2MpfLJbvdrpaWFp/mMyRpiM3Wy1EBfWNaqAMAuuFPevpc0t8lv/477q/OXPG0YmN7voDN5XLLbt/Yq7H2FebQAQAIAwy5AwAs7HNJkQHuHx5I6AAACyOhd2DIHQCAMECFDgCwMCr0DiR0AICFdbycJZD9wwND7gAAhAEqdACAhX2uwGpThtwBAOgHSOgdSOgAAAsjoXdgDh0AgDBAhQ4AsLB2BbZSPXxWuZPQAQAWxm1rHYI+5L5mzRrZbDavNnHixGCfBgAAfEmvVOhf//rX9de//rXzJLcxEAAA6A2fSwrkFdrhsyiuVzLtbbfdpqSkpN44NAAAX0JC79Arq9yPHz+ulJQUZWRk6MEHH1RDQ0O3fd1ut1wul1cDAAD+CXpCz83N1Y4dO1RZWaktW7bo1KlTmjFjhi5fvtxl//LyctntdrOlpqYGOyQAQNj6PAgtPNgMwzB68wTNzc0aO3asNm3apGXLlt3wudvtltvtNv92uVxKTU1VS0uLYmNjfTrHEFsgwy1A/zEt1AEA3fCn+vtc0t8lv/477i+XyyW73a6Wlu8pNvb2AI5zXXb7u70aa1/p9dVqw4cP19e+9jWdOHGiy8+jo6MVHR3d22EAABDWev1JcVeuXNHJkyeVnJzc26cCAAw4Hfeh97RxH3q3nnzySVVXV+v06dP6xz/+ofvuu0+RkZFatGhRsE8FABjwmEPvEPQh9zNnzmjRokW6dOmSRo4cqbvvvlsHDhzQyJEjg30qAMCAF2hCJqF3a8+ePcE+JAAAuAUe4QYAsDAq9A4kdACAhQW6qI1FcQAAoB+h
QgcAWNjnkgJ5Plr4VOgkdACAhZHQOzDkDgBAGCChAwAsLDQPlqmoqFBaWppiYmKUm5urQ4cOddv3jTfe0NSpUzV8+HANGTJE2dnZevXVV736LF26VDabzasVFhb6FRND7gAAC+v7Ife9e/eqpKREW7duVW5urjZv3qyCggLV19crISHhhv5xcXH66U9/qokTJyoqKkpvv/22iouLlZCQoIKCArNfYWGhXnnlFfNvf99zQoUOAIAfNm3apOXLl6u4uFiTJk3S1q1bNXjwYG3fvr3L/rNmzdJ9992nu+66S+PGjdNjjz2mzMxM7d+/36tfdHS0kpKSzDZixAi/4iKhAwAsLDgvZ3G5XF7ty6/1/rK2tjbV1tYqPz/f3BYREaH8/HzV1NTcMlrDMORwOFRfX6+ZM2d6fVZVVaWEhARNmDBBK1as0KVLl/z434GEDgCwtPYgNCk1NVV2u91s5eXlXZ7t4sWLam9vV2Jiotf2xMREOZ3ObqNsaWnR0KFDFRUVpf/4j//QL3/5S33ve98zPy8sLNSuXbvkcDi0YcMGVVdXa+7cuWpv931KgDl0AICFfa7AalOPJKmxsVGxsbHmVn/nr29l2LBhqqur05UrV+RwOFRSUqKMjAzNmjVLkrRw4UKz75QpU5SZmalx48apqqpKs2fP9ukcJHQAwIAXGxvrldC7Ex8fr8jISDU1NXltb2pqUlJSUrf7RUREaPz48ZKk7OxsffrppyovLzcT+ldlZGQoPj5eJ06c8DmhM+QOALCwvr1tLSoqSjk5OXI4HOY2j8cjh8OhvLw8n4/j8Xi6naeXvngV+aVLl5ScnOzzManQAQAWFpwhd3+UlJRoyZIlmjp1qqZPn67NmzertbVVxcXFkqTFixdr1KhR5jx8eXm5pk6dqnHjxsntdutPf/qTXn31VW3ZskWSdOXKFa1du1YLFixQUlKSTp48qaefflrjx4/3uq3tVkjoAAD4oaioSBcuXFBpaamcTqeys7NVWVlpLpRraGhQRETnj4zW1lY98sgjOnPmjAYNGqSJEyfqd7/7nYqKiiRJkZGROnr0qHbu3Knm5malpKRozpw5WrdunV9z+TbDMAK5Iz/oXC6X7Ha7WlpafJrPkKQhNlsvRwX0jWmhDgDohj818OeS/i759d9xf3XmiljFxvY8B7hchux2V6/G2leo0AEAFva5pECKun5V0waERXEAAIQBKnQAgIVRoXcgoQMALIyE3oEhdwAAwgAVOgDAugxPYEV2+BToJHQAgIV51JNnw3jvHyZI6AAA6+p8YVrP9w8TzKEDABAGqNABANZFhW4ioQMArIs5dBND7gAAhAEqdACAdTHkbiKhAwCsiyF3E0PuAACEASp0AIB1eRTYsHkYVegkdACAdTGHbmLIHQCAMECFDgCwLhbFmUjoAADrYsjdREIHAFgXCd3EHDoAAGGACh0AYF3MoZtI6AAA62LI3cSQOwAAYYAKHQBgXYYCGzY3ghVI6JHQAQDWxZC7iSF3AADCABU6AMC6qNBNJHQAgHVx25qJIXcAAPxUUVGhtLQ0xcTEKDc3V4cOHeq27xtvvKGpU6dq+PDhGjJkiLKzs/Xqq6969TEMQ6WlpUpOTtagQYOUn5+v48eP+xUTCR0AYF3tQWh+2rt3r0pKSlRWVqYjR44oKytLBQUFOn/+fJf94+Li9NOf/lQ1NTU6evSoiouLVVxcrHfeecfss3HjRr388svaunWrDh48qCFDhqigoEDXrl3zOS6/E/r777+ve++9VykpKbLZbNq3b5/X58H4lQEAgE9CkNA3bdqk5cuXq7i4WJMmTdLWrVs1ePBgbd++vcv+s2bN0n333ae77rpL48aN02OPPabMzEzt379f0hd5c/PmzXruuec0b948ZWZmateuXTp79uwNOfZm/E7ora2tysrKUkVFRZefB+NXBgAAPvEEoUlyuVxeze12d3m6trY21dbWKj8/39wWERGh/Px81dTU3DJcwzDk
cDhUX1+vmTNnSpJOnTolp9PpdUy73a7c3FyfjtnB70Vxc+fO1dy5c7sN9Mu/MiRp165dSkxM1L59+7Rw4UJ/TwcAQK9LTU31+rusrExr1qy5od/FixfV3t6uxMREr+2JiYk6duxYt8dvaWnRqFGj5Ha7FRkZqV//+tf63ve+J0lyOp3mMb56zI7PfBHUVe63+pXRVUJ3u91ev4RcLlcwQwIAhDOPArv17P8q9MbGRsXGxpqbo6OjAwrrq4YNG6a6ujpduXJFDodDJSUlysjI0KxZs4J2jqAm9J78yigvL9fatWuDGQYAYKAI0m1rsbGxXgm9O/Hx8YqMjFRTU5PX9qamJiUlJXW7X0REhMaPHy9Jys7O1qeffqry8nLNmjXL3K+pqUnJyclex8zOzvb5UkK+yn316tVqaWkxW2NjY6hDAgCgS1FRUcrJyZHD4TC3eTweORwO5eXl+Xwcj8djjk6np6crKSnJ65gul0sHDx7065hBrdB78isjOjo66EMbAIABIgRPiispKdGSJUs0depUTZ8+XZs3b1Zra6uKi4slSYsXL9aoUaNUXl4u6YuR6KlTp2rcuHFyu93605/+pFdffVVbtmyRJNlsNj3++ON64YUXdOeddyo9PV3PP/+8UlJSNH/+fJ/jCmpC//KvjI4E3vErY8WKFcE8FQAAIUnoRUVFunDhgkpLS+V0OpWdna3KykpzurmhoUEREZ0D4K2trXrkkUd05swZDRo0SBMnTtTvfvc7FRUVmX2efvpptba26qGHHlJzc7PuvvtuVVZWKiYmxue4bIZh+PXyuCtXrujEiROSpG984xvatGmT7rnnHsXFxWnMmDHasGGD1q9fr507d5q/Mo4ePapPPvnEp8BcLpfsdrtaWlp8ms+QpCE2mz+XAPRb00IdANANf+ZnP5f0d8mv/477y8wVb0uxQwI4Tqtk/3+9G2tf8btCP3z4sO655x7z75KSEknSkiVLtGPHjqD8ygAAwCc8y93kd4Xe26jQMZBRoaO/6rcV+ptBqNDvC48KPeSr3AEAQOB4fSoAwLp4H7qJhA4AsC5Dgc2D96tJ58CQ0AEA1kWFbmIOHQCAMECFDgCwLm5bM5HQAQDWxZC7iSF3AADCABU6AMC6qNBNJHQAgHUxh25iyB0AgDBAhQ4AsC6G3E0kdACAdXkUWFIOoyF3EjoAwLqYQzcxhw4AQBigQgcAWBdz6CYSOgDAuhhyNzHkDgBAGKBCBwBYF0PuJhI6AMC6SOgmhtwBAAgDVOgAAOtiUZyJhA4AsC6eFGdiyB0AgDBAhQ4AsC6G3E0kdACAdbHK3URCBwBYFwndxBw6AABhgIQOALAuTxBaD1RUVCgtLU0xMTHKzc3VoUOHuu27bds2zZgxQyNGjNCIESOUn59/Q/+lS5fKZrN5tcLCQr9iIqEDAKyrPQjNT3v37lVJSYnKysp05MgRZWVlqaCgQOfPn++yf1VVlRYtWqT33ntPNTU1Sk1N1Zw5c/TZZ5959SssLNS5c+fM9tprr/kVFwkdAAA/bNq0ScuXL1dxcbEmTZqkrVu3avDgwdq+fXuX/X//+9/rkUceUXZ2tiZOnKjf/OY38ng8cjgcXv2io6OVlJRkthEjRvgVFwkdAGBdQarQXS6XV3O73V2erq2tTbW1tcrPzze3RUREKD8/XzU1NT6FfPXqVV2/fl1xcXFe26uqqpSQkKAJEyZoxYoVunTpkm//G3TE4VdvAAD6E0OBzZ8bXxwmNTVVdrvdbOXl5V2e7uLFi2pvb1diYqLX9sTERDmdTp9CfuaZZ5SSkuL1o6CwsFC7du2Sw+HQhg0bVF1drblz56q93fc5AW5bAwAMeI2NjYqNjTX/jo6O7pXzrF+/Xnv27FFVVZViYmLM7QsXLjT/PWXKFGVmZmrcuHGqqqrS7NmzfTo2FToAwLqCNOQeGxvr1bpL6PHx8YqMjFRTU5PX9qamJiUlJd001Jdeeknr16/XX/7yF2VmZt60
b0ZGhuLj43XixImb9vsyEjoAwLr6+La1qKgo5eTkeC1o61jglpeX1+1+Gzdu1Lp161RZWampU6fe8jxnzpzRpUuXlJyc7HNsJHQAAPxQUlKibdu2aefOnfr000+1YsUKtba2qri4WJK0ePFirV692uy/YcMGPf/889q+fbvS0tLkdDrldDp15coVSdKVK1f01FNP6cCBAzp9+rQcDofmzZun8ePHq6CgwOe4mEMHAFhXCB79WlRUpAsXLqi0tFROp1PZ2dmqrKw0F8o1NDQoIqKzXt6yZYva2tr0wAMPeB2nrKxMa9asUWRkpI4ePaqdO3equblZKSkpmjNnjtatW+fXXL7NMAzD/8vpPS6XS3a7XS0tLV4LFG5miM3Wy1EBfWNaqAMAuuHPcO7nkv4u+fXfcX+ZuWKVFBvA+jWXW7L/qndj7StU6AAA6+L1qSbm0AEACANU6AAA6+L1qSYSOgDAujwKLCkz5A4AAPoTKnQAgHWxKM5EQgcAWBdz6CaG3AEACANU6AAA62LI3URCBwBYF0PuJr+H3N9//33de++9SklJkc1m0759+7w+X7p0qWw2m1crLCwMVrwAAKALflfora2tysrK0o9+9CPdf//9XfYpLCzUK6+8Yv7dWy+KBwAMcFToJr8T+ty5czV37tyb9omOjr7li947uN1uud1u82+Xy+VvSACAgYo5dFOvrHKvqqpSQkKCJkyYoBUrVujSpUvd9i0vL5fdbjdbampqb4QEAAhHHU+K62kjoXevsLBQu3btksPh0IYNG1RdXa25c+eqvb3rcY3Vq1erpaXFbI2NjcEOCQCAsBf0Ve4LFy40/z1lyhRlZmZq3Lhxqqqq0uzZs2/oHx0dzRw7AKBn2hVYaRpGc+i9/mCZjIwMxcfH68SJE719KgDAQOMJQgsTvZ7Qz5w5o0uXLik5Obm3TwUAwIDl95D7lStXvKrtU6dOqa6uTnFxcYqLi9PatWu1YMECJSUl6eTJk3r66ac1fvx4FRQUBDVwAAAYcu/kd0I/fPiw7rnnHvPvkpISSdKSJUu0ZcsWHT16VDt37lRzc7NSUlI0Z84crVu3zu958vF2u8/f0Xf8OjLQf/0z1AEAQdCnOZLb1kx+J/RZs2bJMIxuP3/nnXcCCggAAPiPZ7kDAKyLIXcTCR0AYF0kdBPvQwcAIAxQoQMArMtQYAvbul8SZjkkdACAdbVLsgW4f5ggoQMArIuEbmIOHQCAMECFDgCwLh4sYyKhAwCsiyF3E0PuAAD4qaKiQmlpaYqJiVFubq4OHTrUbd9t27ZpxowZGjFihEaMGKH8/Pwb+huGodLSUiUnJ2vQoEHKz8/X8ePH/YqJhA4AsK4QvD517969KikpUVlZmY4cOaKsrCwVFBTo/PnzXfavqqrSokWL9N5776mmpkapqamaM2eOPvvsM7PPxo0b9fLLL2vr1q06ePCghgwZooKCAl27ds3nuGzGzR7MHgIul0t2u10j5fuvjW/2ZkBAH+LlLAgH7ZLqJbW0tCg2NrZXztGRK1pypNgAJo9dn0v2Wv9izc3N1bRp0/SrX/1KkuTxeJSamqpHH31Uzz777C33b29v14gRI/SrX/1KixcvlmEYSklJ0Y9//GM9+eSTkr6IJzExUTt27NDChQt9iosKHQAw4LlcLq/mdru77NfW1qba2lrl5+eb2yIiIpSfn6+amhqfznX16lVdv35dcXFxkr54DbnT6fQ6pt1uV25urs/HlEjoAAAr8+iLIYGetv8bck9NTZXdbjdbeXl5l6e7ePGi2tvblZiY6LU9MTFRTqfTp5CfeeYZpaSkmAm8Y79Ajimxyh0AYGUeBbbK/f8SemNjo9eQe3R0dEBhdWf9+vXas2ePqqqqFBMTE9RjU6EDAAa82NhYr9ZdQo+Pj1dkZKSampq8tjc1NSkpKemm53jppZe0fv16/eUvf1FmZqa5vWO/nhzzy0joAADrCmS4vaP5ISoqSjk5OXI4HOY2j8cj
h8OhvLy8bvfbuHGj1q1bp8rKSk2dOtXrs/T0dCUlJXkd0+Vy6eDBgzc95lcx5A4AsK5AHwzTg/1LSkq0ZMkSTZ06VdOnT9fmzZvV2tqq4uJiSdLixYs1atQocx5+w4YNKi0t1e7du5WWlmbOiw8dOlRDhw6VzWbT448/rhdeeEF33nmn0tPT9fzzzyslJUXz58/3OS4SOgDAuoI0h+6PoqIiXbhwQaWlpXI6ncrOzlZlZaW5qK2hoUEREZ0D4Fu2bFFbW5seeOABr+OUlZVpzZo1kqSnn35ara2teuihh9Tc3Ky7775blZWVfs2zcx860I9wHzrCQZ/ehz5Bio0M4Djtkr2+d2PtK1ToAADrCsGQe39FQgcAWFcIhtz7K1a5AwAQBqjQAQDWFWiFHUYVOgkdAGBd7ZICWdodRgmdIXcAAMIAFToAwLoYcjeR0AEA1sWQu4khdwAAwgAVOgDAuqjQTSR0AIB1MYduIqEDAKzLo8Aq9H71NpPAMIcOAEAYoEIHAFhXoM9yD6MKnYQOALCudpHQ/w9D7gAAhAEqdACAdVGhm0joAADrYg7dxJA7AABhgAodAGBdDLmbSOgAAOsioZsYcgcAIAxQoQMArMtQWFXZgSChAwAsq/3/WiD7hwsSOgDAskjonZhDBwAgDFChAwAsy6PAXmkeRq9DJ6EDAKyLIfdOfg25l5eXa9q0aRo2bJgSEhI0f/581dfXe/W5du2aVq5cqTvuuENDhw7VggUL1NTUFNSgAQCAN78SenV1tVauXKkDBw7o3Xff1fXr1zVnzhy1traafZ544gn98Y9/1Ouvv67q6mqdPXtW999/f9ADBwDAE4QWLmyGYfT4Dr4LFy4oISFB1dXVmjlzplpaWjRy5Ejt3r1bDzzwgCTp2LFjuuuuu1RTU6Nvfetbtzymy+WS3W7XSPn+a+ObPb0AoJ/5Z6gDAIKgXVK9pJaWFsXGxvbKOTpyxT8lBXIGl6Sx6t1Y+0pAq9xbWlokSXFxcZKk2tpaXb9+Xfn5+WafiRMnasyYMaqpqenyGG63Wy6Xy6sBANCfVVRUKC0tTTExMcrNzdWhQ4e67fvxxx9rwYIFSktLk81m0+bNm2/os2bNGtlsNq82ceJEv2LqcUL3eDx6/PHH9e1vf1uTJ0+WJDmdTkVFRWn48OFefRMTE+V0Ors8Tnl5uex2u9lSU1N7GhIAYIDxqHNhXE9aT4bc9+7dq5KSEpWVlenIkSPKyspSQUGBzp8/32X/q1evKiMjQ+vXr1dSUlK3x/3617+uc+fOmW3//v1+xdXjhL5y5Up99NFH2rNnT08PIUlavXq1WlpazNbY2BjQ8QAAA0ew5tC/OlLsdru7PeemTZu0fPlyFRcXa9KkSdq6dasGDx6s7du3d9l/2rRpevHFF7Vw4UJFR0d3e9zbbrtNSUlJZouPj/fnf4qeJfRVq1bp7bff1nvvvafRo0eb25OSktTW1qbm5mav/k1NTd3+KomOjlZsbKxXAwCgL6WmpnqNFpeXl3fZr62tTbW1tV5TyxEREcrPz+92atlXx48fV0pKijIyMvTggw+qoaHBr/39ug/dMAw9+uijevPNN1VVVaX09HSvz3NycnT77bfL4XBowYIFkqT6+no1NDQoLy/Pr8AAALiVYN2H3tjY6FVQdldJX7x4Ue3t7UpMTPTanpiYqGPHjvU4jtzcXO3YsUMTJkzQuXPntHbtWs2YMUMfffSRhg0b5tMx/EroK1eu1O7du/XWW29p2LBh5ry43W7XoEGDZLfbtWzZMpWUlCguLk6xsbF69NFHlZeX59MKdwAA/BGshB7qEeK5c+ea/87MzFRubq7Gjh2rP/zhD1q2bJlPx/AroW/ZskWSNGvWLK/tr7zyipYuXSpJ+vnPf66IiAgtWLBAbrdbBQUF+vWvf+3PaQAA8ElfP/o1Pj5ekZGRNzww7WZTyz0xfPhwfe1rX9OJEyd83sevOXTDMLpsHclckmJiYlRRUaH/
/d//VWtrq954442gXiQAAKESFRWlnJwcORwOc5vH45HD4Qjq1PKVK1d08uRJJScn+7wPz3IHAFhWKJ7lXlJSoiVLlmjq1KmaPn26Nm/erNbWVhUXF0uSFi9erFGjRpkL69ra2vTJJ5+Y//7ss89UV1enoUOHavz48ZKkJ598Uvfee6/Gjh2rs2fPqqysTJGRkVq0aJHPcZHQAQCWFYq3rRUVFenChQsqLS2V0+lUdna2KisrzYVyDQ0NiojoHAA/e/asvvGNb5h/v/TSS3rppZf0ne98R1VVVZKkM2fOaNGiRbp06ZJGjhypu+++WwcOHNDIkSN9jiugR7/2Bh79ioGMR78iHPTlo1/rJPm2BrxrlyVlKzwe/UqFDgCwrI4nxQWyf7ggoQMALIv3oXcK6OUsAACgf6BCBwBYVigWxfVXJHQAgGUx5N6JIXcAAMIAFToAwLKo0DuR0AEAlsUceicSOgDAsqjQOzGHDgBAGKBCBwBYlqHAhs371bPPA0RCBwBYFkPunRhyBwAgDFChAwAsiwq9EwkdAGBZ3LbWiSF3AADCABU6AMCyGHLvREIHAFgWCb0TQ+4AAIQBKnQAgGWxKK4TCR0AYFkeBTZsTkIHAKAfoELvxBw6AABhgAodAGBZrHLvREIHAFgWCb0TQ+4AAIQBKnQAgGWxKK4TCR0AYFkMuXdiyB0AgDBAhQ4AsCwq9E5U6AAAyzLUOY/ek2b08LwVFRVKS0tTTEyMcnNzdejQoW77fvzxx1qwYIHS0tJks9m0efPmgI/ZFRI6AAB+2Lt3r0pKSlRWVqYjR44oKytLBQUFOn/+fJf9r169qoyMDK1fv15JSUlBOWZXSOgAAMtqD0Lz16ZNm7R8+XIVFxdr0qRJ2rp1qwYPHqzt27d32X/atGl68cUXtXDhQkVHRwflmF0hoQMALCuQ4fYv3/Lmcrm8mtvt7vJ8bW1tqq2tVX5+vrktIiJC+fn5qqmp6dE1BOuYJHQAgGUFq0JPTU2V3W43W3l5eZfnu3jxotrb25WYmOi1PTExUU6ns0fXEKxjssodADDgNTY2KjY21vy7u6Hx/oyEDgCwrGDdthYbG+uV0LsTHx+vyMhINTU1eW1vamrqdsFbXx2TIXcAgGUFaw7dV1FRUcrJyZHD4eiMweORw+FQXl5ej64hWMekQgcAwA8lJSVasmSJpk6dqunTp2vz5s1qbW1VcXGxJGnx4sUaNWqUOQ/f1tamTz75xPz3Z599prq6Og0dOlTjx4/36Zi+IKEDACwrFE+KKyoq0oULF1RaWiqn06ns7GxVVlaai9oaGhoUEdE5AH727Fl94xvfMP9+6aWX9NJLL+k73/mOqqqqfDqmL2yGYfT0QTm9wuVyyW63a6R8nw/4Zm8GBPShf4Y6ACAI2iXVS2ppafFpXronOnLFekkxARznmqRn1bux9hXm0AEACAMMuQMALIv3oXcioQMALIu3rXViyB0AgDBAhQ4AsCyG3Dv5VaGXl5dr2rRpGjZsmBISEjR//nzV19d79Zk1a5ZsNptXe/jhh4MaNAAAUmjettZf+ZXQq6urtXLlSh04cEDvvvuurl+/rjlz5qi1tdWr3/Lly3Xu3Dmzbdy4MahBAwAgkdC/zK8h98rKSq+/d+zYoYSEBNXW1mrmzJnm9sGDB/f4mbYAAMB/AS2Ka2lpkSTFxcV5bf/973+v+Ph4TZ48WatXr9bVq1e7PYbb7b7hPbQAAPiir5/l3p/1eFGcx+PR448/rm9/+9uaPHmyuf0HP/iBxo4dq5SUFB09elTPPPOM6uvr9cYbb3R5nPLycq1du7anYQAABjCPAhs2D6eE3uNHv65YsUJ//vOftX//fo0ePbrbfn/72980e/ZsnThxQuPGjbvhc7fbLbfbbf7tcrmUmprKo18xIPHoV4SDvnz06zOSAnlzuVvSBoXHo197VKGvWrVKb7/9tt5///2bJnNJys3NlaRuE3p0dLQlXyQPAAg9HizTya+E
bhiGHn30Ub355puqqqpSenr6Lfepq6uTJCUnJ/coQAAAusN96J38SugrV67U7t279dZbb2nYsGFyOp2SJLvdrkGDBunkyZPavXu3/v3f/1133HGHjh49qieeeEIzZ85UZmZmr1wAAADwM6Fv2bJF0hcPj/myV155RUuXLlVUVJT++te/mi9mT01N1YIFC/Tcc88FLWAAADow5N7J7yH3m0lNTVV1dXVAAQEA4CuG3DvxchYAAMIAL2cBAFgWQ+6dSOgAAMsioXcioQMALMtQYPPgPXqyWj/FHDoAAGGACh0AYFkMuXcioQMALIuE3okhdwAAwgAVOgDAsniwTCcSOgDAshhy78SQOwAAYYAKHQBgWQy5dyKhAwAsiyH3Tgy5AwAQBkjoAADL8qizSu9J6+mQe0VFhdLS0hQTE6Pc3FwdOnTopv1ff/11TZw4UTExMZoyZYr+9Kc/eX2+dOlS2Ww2r1ZYWOhXTCR0AIBleYLQ/LV3716VlJSorKxMR44cUVZWlgoKCnT+/Pku+//jH//QokWLtGzZMn3wwQeaP3++5s+fr48++sirX2Fhoc6dO2e21157za+4bIZh9Ktn07tcLtntdo2U7782vtmbAQF96J+hDgAIgnZJ9ZJaWloUGxvbK+foyBXzJN0ewHGuS3pL/sWam5uradOm6Ve/+pUkyePxKDU1VY8++qieffbZG/oXFRWptbVVb7/9trntW9/6lrKzs7V161ZJX1Tozc3N2rdvX4+vhQodADDguVwur+Z2u7vs19bWptraWuXn55vbIiIilJ+fr5qami73qamp8eovSQUFBTf0r6qqUkJCgiZMmKAVK1bo0qVLfl0DCR0AYFmBzJ9/eYV8amqq7Ha72crLy7s838WLF9Xe3q7ExESv7YmJiXI6nV3u43Q6b9m/sLBQu3btksPh0IYNG1RdXa25c+eqvd33dfjctgYAsKxg3Yfe2NjoNeQeHR0dSFh+W7hwofnvKVOmKDMzU+PGjVNVVZVmz57t0zGo0AEAA15sbKxX6y6hx8fHKzIyUk1NTV7bm5qalJSU1OU+SUlJfvWXpIyMDMXHx+vEiRM+XwMJHQBgWcEacvdVVFSUcnJy5HA4zG0ej0cOh0N5eXld7pOXl+fVX5LefffdbvtL0pkzZ3Tp0iUlJyf7HBsJHQBgWaG4ba2kpETbtm3Tzp079emnn2rFihVqbW1VcXGxJGnx4sVavXq12f+xxx5TZWWl/vM//1PHjh3TmjVrdPjwYa1atUqSdOXKFT311FM6cOCATp8+LYfDoXnz5mn8+PEqKCjwOS7m0AEA8ENRUZEuXLig0tJSOZ1OZWdnq7Ky0lz41tDQoIiIznr53/7t37R7924999xz+slPfqI777xT+/bt0+TJkyVJkZGROnr0qHbu3Knm5malpKRozpw5WrdunV9z+dyHDvQj3IeOcNCX96HPVmCV6eeSHOrdWPsKFToAwLLaJdkC3D9cMIcOAEAYoEIHAFgW70PvREIHAFgWQ+6dSOgAAMsioXdiDh0AgDBAhQ4AsCzm0DuR0AEAlsWQeyeG3AEACANU6AAAyzIU2LB5v3pUaoBI6AAAywp0yJwhdwAA0K9QoQMALIsKvRMJHQBgWR4Ftso9nG5bY8gdAIAwQIUOALAshtw7kdABAJZFQu9EQgcAWBZz6J2YQwcAIAxQoQMALCvQCjucKnQSOgDAskjoncIioYfTogYA6I8G+9GX/yaHRlgkdADAwNSuwF6wEk4Vul+L4rZs2aLMzEzFxsYqNjZWeXl5+vOf/2x+fu3aNa1cuVJ33HGHhg4dqgULFqipqSnoQQMAIH2R0ANt4cKvhD569GitX79etbW1Onz4sL773e9q3rx5+vjjjyVJTzzxhP74xz/q9ddfV3V1tc6ePav777+/VwIHAACdbIZhBPQ62Li4OL344ot64IEHNHLkSO3evVsPPPCAJOnYsWO66667VFNTo29961s+Hc/lcslut2uk
fP+1kdWz0IF+50yoAwC64e8c+geSWlpaFBsb2yvxdOSKBAV2/7VH0nn1bqx9pcf/O7S3t2vPnj1qbW1VXl6eamtrdf36deXn55t9Jk6cqDFjxqimpqbb47jdbrlcLq8GAIAvGHLv5HdC//DDDzV06FBFR0fr4Ycf1ptvvqlJkybJ6XQqKipKw4cP9+qfmJgop9PZ7fHKy8tlt9vNlpqa6vdFAAAw0Pmd0CdMmKC6ujodPHhQK1as0JIlS/TJJ5/0OIDVq1erpaXFbI2NjT0+FgBgYPEosOo8nFa5+33bWlRUlMaPHy9JysnJ0f/8z//oF7/4hYqKitTW1qbm5mavKr2pqUlJSUndHi86OlrR0dH+Rw4AGPACfZZ7QIvI+pmAn+Xu8XjkdruVk5Oj22+/XQ6Hw/ysvr5eDQ0NysvLC/Q0AADcgDn0Tn5V6KtXr9bcuXM1ZswYXb58Wbt371ZVVZXeeecd2e12LVu2TCUlJYqLi1NsbKweffRR5eXl+bzCHQAA9IxfCf38+fNavHixzp07J7vdrszMTL3zzjv63ve+J0n6+c9/roiICC1YsEBut1sFBQX69a9/7VdAHXfR+TOv8blfZwD6r3CqFhBe/Pn/ZkffAO+K9vlcDLl/IeD70IPtzJkzrHQHgDDQ2Nio0aNH98qxr127pvT09JveReWrpKQknTp1SjExMUGILHT6XUL3eDw6e/ashg0bJput83eXy+VSamqqGhsbLX/zvy+43vA1kK5VGljXO5CuVer+eg3D0OXLl5WSkqKIiICXanXr2rVramtrC/g4UVFRlk/mUj98OUtERMRNf9F1PEd+oOB6w9dAulZpYF3vQLpWqevrtdvtvX7emJiYsEjEwdJ7P50AAECfIaEDABAGLJPQo6OjVVZWNmAeQsP1hq+BdK3SwLregXSt0sC73v6u3y2KAwAA/rNMhQ4AALpHQgcAIAyQ0AEACAMkdAAAwgAJHQCAMGCZhF5RUaG0tDTFxMQoNzdXhw4dCnVIvWLNmjWy2WxebeLEiaEOKyjef/993XvvvUpJSZHNZtO+ffu8PjcMQ6WlpUpOTtagQYOUn5+v48ePhybYILjV9S5duvSG77qwsDA0wQaovLxc06ZN07Bhw5SQkKD58+ervr7eq8+1a9e0cuVK3XHHHRo6dKgWLFigpqamEEUcGF+ud9asWTd8vw8//HCIIu65LVu2KDMz03waXF5env785z+bn4fT92p1lkjoe/fuVUlJicrKynTkyBFlZWWpoKBA58+fD3VoveLrX/+6zp07Z7b9+/eHOqSgaG1tVVZWlioqKrr8fOPGjXr55Ze1detWHTx4UEOGDFFBQYGuXbvWx5EGx62uV5IKCwu9vuvXXnutDyMMnurqaq1cuVIHDhzQu+++q+vXr2vOnDlqbW01+zzxxBP64x//qNdff13V1dU6e/as7r///hBG3XO+XK8kLV++3Ov73bhxY4gi7rnRo0dr/fr1qq2t1eHDh/Xd735X8+bN08cffywpvL5XyzMsYPr06cbKlSvNv9vb242UlBSjvLw8hFH1jrKyMiMrKyvUYfQ6Scabb75p/u3xeIykpCTjxRdfNLc1Nzcb0dHRxmuvvRaCCIPrq9drGIaxZMkSY968eSGJp7edP3/ekGRUV1cbhvHFd3n77bcbr7/+utnn008/NSQZNTU1oQozaL56vYZhGN/5zneMxx57LHRB9aIRI0YYv/nNb8L+e7Wafl+ht7W1qba2Vvn5+ea2iIgI5efnq6amJoSR9Z7jx48rJSVFGRkZevDBB9XQ0BDqkHrdqVOn5HQ6vb5nu92u3NzcsP2eJamqqkoJCQmaMGGCVqxYoUuXLoU6pKBoaWmRJMXFxUmSamtrdf36da/vd+LEiRozZkxYfL9fvd4Ov//97xUfH6/Jkydr9erVunr1aijCC5r29nbt2bNHra2tysvLC/vv1Wr63dvWvurixYtqb29XYmKi1/bExEQdO3YsRFH1
ntzcXO3YsUMTJkzQuXPntHbtWs2YMUMfffSRhg0bFurwek3HO427+p6D8b7j/qiwsFD333+/0tPTdfLkSf3kJz/R3LlzVVNTo8jIyFCH12Mej0ePP/64vv3tb2vy5MmSvvh+o6KiNHz4cK++4fD9dnW9kvSDH/xAY8eOVUpKio4ePapnnnlG9fX1euONN0IYbc98+OGHysvL07Vr1zR06FC9+eabmjRpkurq6sL2e7Wifp/QB5q5c+ea/87MzFRubq7Gjh2rP/zhD1q2bFkII0OwLVy40Pz3lClTlJmZqXHjxqmqqkqzZ88OYWSBWblypT766KOwWftxK91d70MPPWT+e8qUKUpOTtbs2bN18uRJjRs3rq/DDMiECRNUV1enlpYW/dd//ZeWLFmi6urqUIeFr+j3Q+7x8fGKjIy8YdVkU1OTkpKSQhRV3xk+fLi+9rWv6cSJE6EOpVd1fJcD9XuWpIyMDMXHx1v6u161apXefvttvffeexo9erS5PSkpSW1tbWpubvbqb/Xvt7vr7Upubq4kWfL7jYqK0vjx45WTk6Py8nJlZWXpF7/4Rdh+r1bV7xN6VFSUcnJy5HA4zG0ej0cOh0N5eXkhjKxvXLlyRSdPnlRycnKoQ+lV6enpSkpK8vqeXS6XDh48OCC+Z0k6c+aMLl26ZMnv2jAMrVq1Sm+++ab+9re/KT093evznJwc3X777V7fb319vRoaGiz5/d7qertSV1cnSZb8fr/K4/HI7XaH3fdqeaFeleeLPXv2GNHR0caOHTuMTz75xHjooYeM4cOHG06nM9ShBd2Pf/xjo6qqyjh16pTx3//930Z+fr4RHx9vnD9/PtShBezy5cvGBx98YHzwwQeGJGPTpk3GBx98YPzzn/80DMMw1q9fbwwfPtx46623jKNHjxrz5s0z0tPTjX/9618hjrxnbna9ly9fNp588kmjpqbGOHXqlPHXv/7V+OY3v2nceeedxrVr10Idut9WrFhh2O12o6qqyjh37pzZrl69avZ5+OGHjTFjxhh/+9vfjMOHDxt5eXlGXl5eCKPuuVtd74kTJ4yf/exnxuHDh41Tp04Zb731lpGRkWHMnDkzxJH779lnnzWqq6uNU6dOGUePHjWeffZZw2azGX/5y18Mwwiv79XqLJHQDcMwfvnLXxpjxowxoqKijOnTpxsHDhwIdUi9oqioyEhOTjaioqKMUaNGGUVFRcaJEydCHVZQvPfee4akG9qSJUsMw/ji1rXnn3/eSExMNKKjo43Zs2cb9fX1oQ06ADe73qtXrxpz5swxRo4cadx+++3G2LFjjeXLl1v2R2pX1ynJeOWVV8w+//rXv4xHHnnEGDFihDF48GDjvvvuM86dOxe6oANwq+ttaGgwZs6cacTFxRnR0dHG+PHjjaeeespoaWkJbeA98KMf/cgYO3asERUVZYwcOdKYPXu2mcwNI7y+V6vjfegAAISBfj+HDgAAbo2EDgBAGCChAwAQBkjoAACEARI6AABhgIQOAEAYIKEDABAGSOgAAIQBEjoAAGGAhA4AQBggoQMAEAb+f3rdzQ37XmlBAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "k = 45\n",
    "plt.imshow(A[k,:,:].detach().cpu(), cmap='hot', interpolation='nearest')\n",
    "plt.colorbar()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAELCAYAAABEYIWnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiPUlEQVR4nO3deXBV5f3H8e8l+55AiISlCRBkC4hGQgcIEURQAm5FFBUJDKJ1A4tSlFYIoAzKsDQREcciUrQC4lIFWRwcF5iptoJQlSIQrCJF1oYgW/L9/eHk/nJyk3NuPBxu4Hm/Zjo15znP85z7POccPrm5zz0+VVUBAADGahTqAwAAAKFFGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADHfRhYGXXnpJfD6flJaWhvpQPFVaWio+n09eeumlUB/KeZOZmSmFhYWhPgw0UFOnThWfzycHDx4M2TF88MEH4vP55IMPPgjZMVRXWFgomZmZ57RN065DL8awIbrowgB+mX379snUqVNly5YtAWWvvPKKzJs377wcx6ZNm2Tq1Kly9OjR89IfgNB66qmn5M033wzYfj7vBXb3P1MQBiAiP18MRUVFDSIMFBUV1XoD2LFjh7zwwgvn5TgAnB92YaCue8G5Znf/e+GFF2THjh2eH0OoEQbOkfLy8lAfwkUvKipKIiIiQn0YMJiqyk8//RTqw8B5FBERIVFRUaE+DM8ZEwbWrFkjeXl5EhcXJwkJCVJQUCD/+te/LPt88cUXUlhYKG3atJHo6Ghp1qyZjB49Wg4dOmTZr+pvk19++aXcfvvtkpKSIr179xaRn/+eNnjwYPn4448lNzdXoqOjpU2bNvLyyy8HHNPRo0dl/Pjx0qpVK4mKipKsrCyZNWuWVFZWBuxXWFgoSUlJkpycLCNHjgw6LR8+fFgeeeQR6dKli8THx0tiYqJcd911snXrVv8+H3zwgXTv3l1EREaNGiU+n8//eYSrrrpK3n33Xdm7d69/e/W/n506dUqmTJkiWVlZEhUVJa1atZKJEyfKqVOnLMfh8/nkgQcekDfffFOys7MlKipKOnfuLO+9955lXB999FEREWndurW/v6rPf9T2t8rdu3fLLbfcIo0bN5bY2Fj59a9/Le+++65ln6q/4y5fvlyefPJJadmypURHR8vVV18t33zzTVDjiAtH1fWSnJwsSUlJMmrUKDlx4oRln8WLF0u/fv0kLS1NoqKipFOnTvLcc88FtFV1Pa9du1auvPJKiYmJkeeff15ERL777ju58cYbJS4uTtLS0uThhx8OOO+drFmzRvLz8yUhIUESExOle/fu8sorr1j2WbFiheTk5EhMTIykpqbKnXfeKd9//31AW1XXVnR0tGRnZ8sbb7xRa5+VlZUyb9486dy5s0RHR8sll1wi99xzjxw5csSyn6rKjBkzpGXLlhIbGyt9+/YNuGfamT17tvTs2VOaNGkiMTExkpOTIytXrrTs4/P5pLy8XJYsWeK/3gsLCx3vBSIif/nLX/zj0rhxY7ntttvkP//5j6X9q666SrKzs+XLL7+Uvn37SmxsrLRo0UKefvpp/z529z+R2j8zUF5eLhMmTPDfu9u3by+zZ8+Wmg8BDua+12DoRWbx4sUqIrpnzx7/tpdffll9Pp9ee+21WlxcrLNmzdLMzExNTk627Dd79mzNy8vTadOm6aJFi3TcuHEaExOjubm5WllZ6d9vypQpKiLaqVMnveGGG3TBggX67LPPqqpqRkaGtm/fXi+55BJ9/PHHtaSkRK+44gr1+Xy6fft2fxvl5eXatWtXbdKkiT7++OO6cOFCveuuu9Tn8+m4ceP8+1VWVmqfPn20UaNGet9992lxcbH269dPu3btqiKiixcvth2PTz/9VNu2bauTJk3S559/XqdNm6YtWrTQpKQk/f7771VVdf/+/Tpt2jQVER07dqwuXbpUly5d
qrt27dJ169Zpt27dNDU11b/9jTfeUFXViooKHTBggMbGxur48eP1+eef1wceeEDDw8P1hhtusByHiOhll12m6enpOn36dJ03b562adNGY2Nj9eDBg6qqunXrVh0+fLiKiM6dO9ff3/Hjx/1jO3LkSH+b+/fv10suuUQTEhJ08uTJOmfOHL3sssu0UaNGumrVKv9+GzduVBHRyy+/XHNycnTu3Lk6depUjY2N1dzcXNvxw4Wj6rq8/PLL9eabb9YFCxbomDFjVER04sSJln27d++uhYWFOnfuXC0uLtYBAwaoiGhJSYllv4yMDM3KytKUlBSdNGmSLly4UDdu3KgnTpzQSy+9VKOjo3XixIk6b948zcnJ8V+XGzdudDzexYsXq8/n0+zsbH3yySf12Wef1TFjxuiIESMs+4iIdu/eXefOnauTJk3SmJgYzczM1CNHjvj3W7t2rTZq1Eizs7N1zpw5OnnyZE1KStLOnTtrRkaGpd8xY8ZoeHi43n333bpw4UL9/e9/r3Fxcdq9e3c9ffq0f78//OEPKiI6aNAgLSkp0dGjR2vz5s01NTXVch3WpWXLlnrfffdpSUmJzpkzR3Nzc1VE9J133vHvs3TpUo2KitK8vDz/9b5p0ybHe8GMGTPU5/PprbfeqgsWLNCioiJNTU0NGJf8/Hxt3ry5tmrVSseNG6cLFizQfv36qYjo6tWrVdX+/qeqOnLkSMsYVlZWar9+/dTn8+mYMWO0pKREhwwZoiKi48ePt4xBMPe9huKiDwNlZWWanJysd999t2W//fv3a1JSkmX7iRMnAtp79dVXVUT0ww8/9G+ruukMHz48YP+MjIyA/Q8cOKBRUVE6YcIE/7bp06drXFyc/vvf/7bUnzRpkoaFhem3336rqqpvvvmmiog+/fTT/n3Onj2reXl5QYWBkydPakVFhWXbnj17NCoqSqdNm+bf9umnn9bZXkFBQcANRfXnC7lRo0b60UcfWbYvXLhQRUQ/+eQT/zYR0cjISP3mm2/827Zu3aoiosXFxf5tzzzzTECYq1IzDIwfP15FxNJ/WVmZtm7dWjMzM/2vuyoMdOzYUU+dOuXfd/78+Soium3btoC+cOGpui5Hjx5t2X7TTTdpkyZNLNtqu9YHDhyobdq0sWyrup7fe+89y/Z58+apiOjy5cv928rLyzUrKyuoMHD06FFNSEjQHj166E8//WQpq/rF4/Tp05qWlqbZ2dmWfd555x0VEX3iiSf827p166bp6el69OhR/7Z169apiFiu3Y8++khFRJctW2bp87333rNsP3DggEZGRmpBQYHlF6HHH39cRSSoMFBzjE+fPq3Z2dnar18/y/a4uLha26vrXlBaWqphYWH65JNPWrZv27ZNw8PDLdvz8/NVRPTll1/2bzt16pQ2a9ZMf/Ob3/i32d3/aoaBqnvyjBkzLPsNHTpUfT6f5R4X7H2vIbjo/0ywfv16OXr0qAwfPlwOHjzo/19YWJj06NFDNm7c6N83JibG/98nT56UgwcPyq9//WsREfnnP/8Z0Pa9995ba5+dOnWSvLw8/89NmzaV9u3by+7du/3bVqxYIXl5eZKSkmI5rv79+0tFRYV8+OGHIiKyevVqCQ8Pl9/+9rf+umFhYfLggw8G9fqjoqKkUaOfp7miokIOHTok8fHx0r59+1pfU32sWLFCOnbsKB06dLC8hn79+omIWMZWRKR///7Stm1b/89du3aVxMREy7jUx+rVqyU3N9f/JxoRkfj4eBk7dqyUlpbKl19+adl/1KhREhkZ6f+5ao5+af9omGpel3l5eXLo0CH53//+599W/Vo/duyYHDx4UPLz82X37t1y7NgxS/3WrVvLwIEDLdtWr14t6enpMnToUP+22NhYGTt2bFDHuH79eikrK5NJkyZJdHS0pczn84mIyGeffSYHDhyQ++67z7JPQUGBdOjQwf/nsB9++EG2bNkiI0eOlKSkJP9+11xzjXTq1MnS9ooVKyQp
KUmuueYayzWbk5Mj8fHx/mt2w4YNcvr0aXnwwQf9xyMiMn78+KBen4h1jI8cOSLHjh2TvLw81/edVatWSWVlpQwbNszyGpo1aybt2rULuO/Ex8fLnXfe6f85MjJScnNzXd13wsLC5KGHHrJsnzBhgqiqrFmzxrL9XN/3vBIe6gPw2s6dO0VE/P9A1ZSYmOj/78OHD0tRUZH89a9/lQMHDlj2q3mDEPn5JlGbX/3qVwHbUlJSLH+T27lzp3zxxRfStGnTWtuo6n/v3r2Snp4u8fHxlvL27dvXWq+myspKmT9/vixYsED27NkjFRUV/rImTZoE1UZddu7cKV999ZXja6gSzLjUx969e6VHjx4B2zt27Ogvz87OrrP/lJQUEZFf3D8aJrt5rrreP/nkE5kyZYps3rw54PMEx44ds/yjWtt1vnfvXsnKyrL8QykSeF0eP35cjh8/7v85LCxMmjZtKrt27RIRsZyftfVRW5siIh06dJCPP/7Ysl+7du0C9qsZ+nfu3CnHjh2TtLS0Wvusft+prc2mTZv6x9PJO++8IzNmzJAtW7ZYPktRc8zqa+fOnaKqtb5eEQn4kHHLli0D+kxJSZEvvvjiF/W/d+9ead68uSQkJFi2V7/vVHeu73teuejDQNWH8ZYuXSrNmjULKA8P//8hGDZsmGzatEkeffRR6datm8THx0tlZaVce+21AR/qE7Em3+rCwsJq3a7VPlxSWVkp11xzjUycOLHWfS+99NK6X1Q9PPXUU/LHP/5RRo8eLdOnT5fGjRtLo0aNZPz48bW+pvqorKyULl26yJw5c2otb9WqleXnYMbFS6HuH+eH0zzv2rVLrr76aunQoYPMmTNHWrVqJZGRkbJ69WqZO3duwHVR13UejNmzZ0tRUZH/54yMjJB+IVplZaWkpaXJsmXLai2vK9jX10cffSTXX3+99OnTRxYsWCDp6ekSEREhixcvDviAZH1VVlaKz+eTNWvW1DrXNX9xCvV1H+r+g3XRh4Gqt2fS0tKkf//+de535MgRef/996WoqEieeOIJ//aqdxa8OK7jx4/bHpPIzzeP999/X44fP245yYNd97py5Urp27evvPjii5btR48eldTUVP/Pdmm9rrK2bdvK1q1b5eqrr3ad9oM5jpoyMjJqHYevv/7aXw7U9Le//U1OnTolb7/9tuW3tppvL9vJyMiQ7du3i6paztma5+Ndd91l+TNWVbCoui9t375dsrKy6uyjqs2a72zu2LHDX171/7Xdq2oeT9u2bWXDhg3Sq1cv25BTvc02bdr4t//4449B/Ub7+uuvS3R0tKxdu9ayLG/x4sUB+9Z1zdvdd1RVWrdufc5+aarvfWfDhg1SVlZmeXfgQr/vXPSfGRg4cKAkJibKU089JWfOnAko//HHH0Xk/9NbzbTm1ZftDBs2TDZv3ixr164NKDt69KicPXtWREQGDRokZ8+etSx7qqiokOLi4qD6CQsLC3hNK1asCFiaFBcX5++7pri4uFr/TDJs2DD5/vvva/0ioJ9++ukXffeC3XHUNGjQIPn73/8umzdv9m8rLy+XRYsWSWZmZsDfSwGR2q/1Y8eO1foPVV0GDRok+/btsyyVO3HihCxatMiyX5s2baR///7+//Xq1UtERAYMGCAJCQkyc+ZMOXnypKVO1XFdeeWVkpaWJgsXLrS8zb5mzRr56quvpKCgQERE0tPTpVu3brJkyRLLdbp+/fqAz80MGzZMKioqZPr06QGv6ezZs/7rrn///hIRESHFxcWWcQr2fhgWFiY+n8/yZ8nS0tJav1woLi6uzvuOSOC94Oabb5awsDApKioKuLepasBS8GDU975TUVEhJSUllu1z584Vn88n1113Xb37bwgu+ncGEhMT5bnnnpMRI0bIFVdcIbfddps0bdpUvv32W3n33XelV69eUlJSIomJidKnTx95+umn5cyZM9KiRQtZt26d7Nmzx5PjevTRR+Xtt9+WwYMHS2FhoeTk
5Eh5ebls27ZNVq5cKaWlpZKamipDhgyRXr16yaRJk6S0tFQ6deokq1atqvUf59oMHjxYpk2bJqNGjZKePXvKtm3bZNmyZZa0L/Jz2k5OTpaFCxdKQkKCxMXFSY8ePaR169aSk5Mjr732mvzud7+T7t27S3x8vAwZMkRGjBghy5cvl3vvvVc2btwovXr1koqKCvn6669l+fLl/rXZ9ZGTkyMiIpMnT5bbbrtNIiIiZMiQIf6LtbpJkybJq6++Ktddd5089NBD0rhxY1myZIns2bNHXn/9df8HJ4HqBgwYIJGRkTJkyBC555575Pjx4/LCCy9IWlqa/PDDD0G1cffdd0tJSYncdddd8o9//EPS09Nl6dKlEhsbG1T9xMREmTt3rowZM0a6d+/u/76SrVu3yokTJ2TJkiUSEREhs2bNklGjRkl+fr4MHz5c/vvf/8r8+fMlMzNTHn74YX97M2fOlIKCAundu7eMHj1aDh8+LMXFxdK5c2fLZxby8/PlnnvukZkzZ8qWLVtkwIABEhERITt37pQVK1bI/PnzZejQodK0aVN55JFHZObMmTJ48GAZNGiQfP7557JmzRrLO4p1KSgokDlz5si1114rt99+uxw4cECeffZZycrKCvhbfU5OjmzYsEHmzJkjzZs3l9atW0uPHj3qvBe0bdtWZsyYIY899piUlpbKjTfeKAkJCbJnzx554403ZOzYsfLII48ENQ9V7O5/NQ0ZMkT69u0rkydPltLSUrnssstk3bp18tZbb8n48eMtHxa8oJz39Qseq+17BlR/Xl42cOBATUpK0ujoaG3btq0WFhbqZ5995t/nu+++05tuukmTk5M1KSlJb7nlFt23b5+KiE6ZMsW/X9USph9//DGg/4yMDC0oKAjYnp+fr/n5+ZZtZWVl+thjj2lWVpZGRkZqamqq9uzZU2fPnm1Z73vo0CEdMWKEJiYmalJSko4YMUI///zzoJcWTpgwQdPT0zUmJkZ79eqlmzdvrvV43nrrLe3UqZOGh4db2j5+/LjefvvtmpycHLBU6fTp0zpr1izt3LmzRkVFaUpKiubk5GhRUZEeO3bMv5+I6P3331/reNVcVjR9+nRt0aKFNmrUyDKXte27a9cuHTp0qCYnJ2t0dLTm5uZa1jGr/v/SwhUrVli279mzJ6gxxIWhruuytnvC22+/rV27dtXo6GjNzMzUWbNm6Z///OeA/eq6nlVV9+7dq9dff73GxsZqamqqjhs3zr9EL5jvGag6jp49e2pMTIwmJiZqbm6uvvrqq5Z9XnvtNb388ss1KipKGzdurHfccYd+9913AW29/vrr2rFjR42KitJOnTrpqlWrApbFVVm0aJHm5ORoTEyMJiQkaJcuXXTixIm6b98+/z4VFRVaVFTkv3dcddVVun379lqvw9q8+OKL2q5dO42KitIOHTro4sWL/XNU3ddff619+vTRmJiYgGWLdd0Lql5v7969NS4uTuPi4rRDhw56//33644dO/z75Ofna+fOnQOOrbZxqev+V9u+ZWVl+vDDD2vz5s01IiJC27Vrp88884xlGaZq/e57oeZTbWCfYgAAAOcV76MCAGA4wgAAAIYjDAAAYDjCAAAAhiMMAABgOMIAAACGIwwAAGC4oL+BcIbDdzf/aFPm9KDG0zZlkTZlXqr7WWLO3Dyg8wqH8hM2ZV871O1Qz2Opzs0cOvVr17aX82B3brWxKXPi5fn+t4vwa0HiztFzLQDUrdzh3sE7AwAAGI4wAACA4QgDAAAYjjAAAIDhCAMAABiOMAAAgOGCfoRxO4flP3ZL4g47tH2lTdlnDnW94rTEz46bhFXpUF5qU+a0HM5pyZsdN3Po1K9d217Og9251dhFv16e7+svwKWFI10sHVx5Do/DdENtypzG2a6uEzdz6NRvQzw/QjVWTlhaCAAAbBEGAAAwHGEAAADDEQYAADAcYQAAAMMRBgAAMFzQTy10UmZTFuFQ1245nVNdryS4qGs3
Fm77tWvbaVmim+NyM4dO/dq17eU82PUbqrEK1fnupf0O5R/blPX2sK5X7I4plNwsW3NT180cOvVr17aX82DXb6jGyu35zjsDAAAYjjAAAIDhCAMAABiOMAAAgOEIAwAAGI4wAACA4QgDAAAYLuhHGDut2rzVl1dn2RkXLYdqrXCGi7rLXNS9w6F8n03ZRoe6fet5LNW5mUOnfu3a9nIe7M4tN+v9vTzfV12EjzD26hG3DfHxtk5MfFRwQ5zDhvroZDdjxSOMAQCALcIAAACGIwwAAGA4wgAAAIYjDAAAYDjCAAAAhgt6aWE7h+VB7VwcRBebsm0u2nWjp4u6Xj7CeJNNmd04ing7ll7NoZfzEKpzy81Yrb4AlxYOdLh38Ajji5ubOXTTdqgeYezV63Fq26nuWpYWAgAAO4QBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHDhwe64U9W2vJ/PV2dZpUPbjW3KTjjU9cpMF3WfcFF3mkN5D5syu3EUETlZz2Opzs0cOvVr17aX82B3brlJyRfi+e6lKxzKP/6FZU7c1HXjIRd1/+Rhv27adsPLebBr28t58Oo1hfJ8550BAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwPlWHZxP72T8g8VZfXp1lZ1y03NuhrlcyXNRd5qLuHQ7l+2zKNjrU7VvPY6nOzRw69WvXtpfzYHduRbjo18vzfVWwl2sDMtLm8eYiIitdtD3Uo3ZDxe71OPFqHN227abvUM1hKMfDjpuxKne4d/DOAAAAhiMMAABgOMIAAACGIwwAAGA4wgAAAIYjDAAAYDjCAAAAhiMMAABgOMIAAACGIwwAAGA4wgAAAIYjDAAAYDjCAAAAhiMMAABguHo8wtheP5vHlFY61L3OpmzNLzoa9z51UfcJF3WnOZT3sCkb6FB3bT2PpTo3c+jUr13bXs6D3bnlJiV7eb5/cAE+wvgxh0cY/+k8Hcf58pCLum7Gwqnfi22cnYRqHhoqHmEMAABsEQYAADAcYQAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDBf09A+0c1gpfYVN22KHtK23KPnOo6xW71+PEyzXqpTZlbRzq7q7foVi4mUOnfu3a9nIe7M6txi769fJ8X38Bfs/ASId7h52V5/A4TDfUpsxpnO3qOnEzh079NsTzI1Rj5YTvGQAAALYIAwAAGI4wAACA4QgDAAAYjjAAAIDhCAMAABiuHo8w/ti29FZfXp1lZ1y03NuhrlcyXNRd5qLuHQ7l+2zKNjrU7VvPY6nOzRw69WvXtpfzYHduRbjo18vzfdVFuLTQq6VnDXHZmRMTl/A1xDlsqEsa3YwVSwsBAIAtwgAAAIYjDAAAYDjCAAAAhiMMAABgOMIAAACGIwwAAGA4HmFcBx5hbMUjjIPHI4yteIRxw8AjjM8PHmEMAAAuSIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMF36uGiqzKYtwqGv32F6nul5JcFHXbizc9mvXttPjj90cl5s5dOrXrm0v58Gu31CNVajOdy/tdyj/2Kast4d1vWJ3TKHk5vG4buq6mUOnfu3a9nIe7PoN1Vi5Pd95ZwAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDEQYAADAcYQAAAMP5VFWD29V+1eatvrw6y864aDlUa4UzXNRd5qLuHQ7l+2zKNjrU7VvP
Y6nOzRw69WvXtpfzYHduuVnv7+X5virYy7UBGenz2Za7WZc91KN2Q8Xu9Tjxahzdtu2m71DNYSjHw46bsSp3uHfwzgAAAIYjDAAAYDjCAAAAhiMMAABgOMIAAACGIwwAAGC4oJcWtnNYHnSFTdlhh7avtCn7zKGuV+xejxM3CcvpMcSlNmVtHOrurt+hWLiZQ6d+7dr2ch7szq3GLvr18nxffxEuLbRzIS4PbKjcLEszccnjLxWqsXLC0kIAAGCLMAAAgOEIAwAAGI4wAACA4QgDAAAYjjAAAIDhCAMAABgu/Fw1VGZT5vQ4WLu19W4eJetGgou6dmPhtl+7tp2+o8DNcbmZQ6d+7dr2ch7s+g3VWIXqfPfSfodyN490boiPP7d/2HvouFnD7qaumzl06teubS/nwa7fUI2V2/OddwYAADAcYQAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDBb20cKfD4w/72Tym1GnJm93jYk841PXKTBd1n3BRd5pDeQ+bMqfH7p6s57FU52YOnfq1a9vLebA7t7x8DHVDPN+95PQYarvlUm6Wh4Vqid9DLur+ycN+3bTthpfzYNe2l/Pg1WsK5fnOOwMAABiOMAAAgOEIAwAAGI4wAACA4QgDAAAYjjAAAIDhCAMAABiOMAAAgOEIAwAAGI4wAACA4QgDAAAYjjAAAIDhCAMAABiOMAAAgOF8qg7PJvazf0Dirb68OsvOuGi5t0Ndr2S4qLvMRd07HMr32ZRtdKjbt57HUp2bOXTq165tL+fB7tyKcNGvl+f7qmAv1wZkpM3jzUVEVrpoe6hH7YaK3etx4tU4um3bTd+hmsNQjocdN2NV7nDv4J0BAAAMRxgAAMBwhAEAAAxHGAAAwHCEAQAADEcYAADAcIQBAAAMF/T3DLRzWCvczsVBdLEp2+aiXTd6uqhb5qJugkP5Jpsyu3EU8XYsvZpDL+chVOeWm7FafQF+z8BAh3uHm+9daIjfUWL/jSzmcTOHbtr2ch686tfL830t3zMAAADsEAYAADAcYQAAAMMRBgAAMBxhAAAAwxEGAAAwXPi5ashuSdxhh7p2icTpcbBecbM80E3CcurXbpyd+nVatmjHzRw69WvXtpfzYHduNXbR74V4vnupmUO5V4+wZYmflZtxNvHRynbszq1QjZXb8513BgAAMBxhAAAAwxEGAAAwHGEAAADDEQYAADAcYQAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDEQYAADAcYQAAAMMRBgAAMBxhAAAAw52zRxjbPWo2wqFupYu6XnHzuF83j9116teubbtxdKrrxM0cOvVr17aX82DXb6jGKlTnu5f2O5TbPXq1t4d1vdJQH53s5vG4buq6mUOnfu3a9nIe7PoN1Vi5Pd95ZwAAAMMRBgAAMBxhAAAAwxEGAAAwHGEAAADDEQYAADAcYQAAAMP5VFWD29V+1eatvrw6y864aDlUa4UzXNRd5qLuHQ7l+2zKNjrU7VvPY6nOzRw69WvXtpfzYHduuVnv7+X5virYy7UBGenz2Za7WZc91KN2Q8Xu9Tjxahzdtu2m71DNYSjHw46bsSp3uHfwzgAAAIYjDAAAYDjCAAAAhiMMAABgOMIAAACGIwwAAGC4oJcWtnNYHtTOxUF0sSnb5qJdN3q6qOvlI4w32ZTZjaOIt2Pp1Rx6OQ+hOrfcjNXqC3Bp4UCHewePML64uZlDN22H6hHGXr0ep7ad6q5laSEAALBDGAAAwHCEAQAADEcYAADAcIQBAAAMRxgAAMBwhAEAAAxXj0cYAwCAixHvDAAAYDjCAAAAhiMMAABgOMIAAACGIwwAAGA4wgAAAIYjDAAAYDjCAAAAhiMMAABguP8DpgkVSbNCdM0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, (ax1, ax2) = plt.subplots(1, 2)\n",
    "\n",
    "# attention map produced by the trained model on one sample\n",
    "_, _, A = gptmini(one_sample.to(device), verbose=True)\n",
    "ax1.imshow(A[0].detach().cpu(), cmap='hot', interpolation='nearest')\n",
    "ax1.axis('off')\n",
    "ax1.set_title('learned attention')\n",
    "# attention map of the hard-coded model on the same sample\n",
    "_, Ah = gh(one_sample)\n",
    "ax2.imshow(Ah[0].detach().cpu(), cmap='hot', interpolation='nearest')\n",
    "ax2.axis('off')\n",
    "ax2.set_title('hard-coded attention')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
